Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
0980b27f
Unverified
Commit
0980b27f
authored
Jan 02, 2026
by
Yan Ru Pei
Committed by
GitHub
Jan 02, 2026
Browse files
feat: mockers with bootstrap optimization (sglang testing) + CI test (#5121)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
cd8dddee
Changes
20
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
895 additions
and
295 deletions
+895
-295
components/src/dynamo/mocker/args.py
components/src/dynamo/mocker/args.py
+31
-3
components/src/dynamo/mocker/main.py
components/src/dynamo/mocker/main.py
+35
-1
components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py
...src/dynamo/sglang/request_handlers/llm/prefill_handler.py
+0
-15
lib/bindings/c/src/lib.rs
lib/bindings/c/src/lib.rs
+6
-0
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+66
-63
lib/llm/src/kv_router/prefill_router.rs
lib/llm/src/kv_router/prefill_router.rs
+56
-119
lib/llm/src/kv_router/subscriber.rs
lib/llm/src/kv_router/subscriber.rs
+90
-42
lib/llm/src/local_model.rs
lib/llm/src/local_model.rs
+18
-1
lib/llm/src/mocker.rs
lib/llm/src/mocker.rs
+1
-0
lib/llm/src/mocker/bootstrap.rs
lib/llm/src/mocker/bootstrap.rs
+369
-0
lib/llm/src/mocker/engine.rs
lib/llm/src/mocker/engine.rs
+38
-0
lib/llm/src/mocker/protocols.rs
lib/llm/src/mocker/protocols.rs
+13
-0
lib/llm/src/preprocessor.rs
lib/llm/src/preprocessor.rs
+1
-0
lib/llm/src/protocols/common/preprocessor.rs
lib/llm/src/protocols/common/preprocessor.rs
+8
-0
lib/llm/src/protocols/common/timing.rs
lib/llm/src/protocols/common/timing.rs
+31
-6
lib/llm/src/protocols/openai/nvext.rs
lib/llm/src/protocols/openai/nvext.rs
+12
-0
lib/runtime/src/component/endpoint.rs
lib/runtime/src/component/endpoint.rs
+9
-4
lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs
...ntime/src/pipeline/network/ingress/shared_tcp_endpoint.rs
+17
-3
lib/runtime/src/pipeline/network/manager.rs
lib/runtime/src/pipeline/network/manager.rs
+40
-22
tests/router/test_router_e2e_with_mockers.py
tests/router/test_router_e2e_with_mockers.py
+54
-16
No files found.
components/src/dynamo/mocker/args.py
View file @
0980b27f
...
@@ -108,12 +108,15 @@ def create_temp_engine_args_file(args) -> Path:
...
@@ -108,12 +108,15 @@ def create_temp_engine_args_file(args) -> Path:
"speedup_ratio"
:
getattr
(
args
,
"speedup_ratio"
,
None
),
"speedup_ratio"
:
getattr
(
args
,
"speedup_ratio"
,
None
),
"dp_size"
:
getattr
(
args
,
"dp_size"
,
None
),
"dp_size"
:
getattr
(
args
,
"dp_size"
,
None
),
"startup_time"
:
getattr
(
args
,
"startup_time"
,
None
),
"startup_time"
:
getattr
(
args
,
"startup_time"
,
None
),
"planner_profile_data"
:
str
(
getattr
(
args
,
"planner_profile_data"
,
None
))
"planner_profile_data"
:
(
if
getattr
(
args
,
"planner_profile_data"
,
None
)
str
(
getattr
(
args
,
"planner_profile_data"
,
None
))
else
None
,
if
getattr
(
args
,
"planner_profile_data"
,
None
)
else
None
),
"is_prefill"
:
getattr
(
args
,
"is_prefill_worker"
,
None
),
"is_prefill"
:
getattr
(
args
,
"is_prefill_worker"
,
None
),
"is_decode"
:
getattr
(
args
,
"is_decode_worker"
,
None
),
"is_decode"
:
getattr
(
args
,
"is_decode_worker"
,
None
),
"enable_local_indexer"
:
getattr
(
args
,
"enable_local_indexer"
,
None
),
"enable_local_indexer"
:
getattr
(
args
,
"enable_local_indexer"
,
None
),
# Note: bootstrap_port is NOT included here - it's set per-worker in launch_workers()
}
}
# Remove None values to only include explicitly set arguments
# Remove None values to only include explicitly set arguments
...
@@ -142,6 +145,13 @@ def validate_worker_type_args(args):
...
@@ -142,6 +145,13 @@ def validate_worker_type_args(args):
)
)
def
parse_bootstrap_ports
(
ports_str
:
str
|
None
)
->
list
[
int
]:
"""Parse comma-separated bootstrap ports string into list of integers."""
if
not
ports_str
:
return
[]
return
[
int
(
p
.
strip
())
for
p
in
ports_str
.
split
(
","
)]
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
description
=
"Mocker engine for testing Dynamo LLM infrastructure with vLLM-style CLI."
,
description
=
"Mocker engine for testing Dynamo LLM infrastructure with vLLM-style CLI."
,
...
@@ -291,6 +301,15 @@ def parse_args():
...
@@ -291,6 +301,15 @@ def parse_args():
default
=
False
,
default
=
False
,
help
=
"Enable worker-local KV indexer for tracking this worker's own KV cache state (default: False)"
,
help
=
"Enable worker-local KV indexer for tracking this worker's own KV cache state (default: False)"
,
)
)
parser
.
add_argument
(
"--bootstrap-ports"
,
type
=
str
,
default
=
None
,
help
=
"Comma-separated list of bootstrap ports for disaggregated serving rendezvous. "
"One port per worker (must match --num-workers). "
"Prefill workers listen on these ports; decode workers connect to them. "
"If not specified, bootstrap rendezvous is disabled."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--store-kv"
,
"--store-kv"
,
type
=
str
,
type
=
str
,
...
@@ -313,6 +332,15 @@ def parse_args():
...
@@ -313,6 +332,15 @@ def parse_args():
if
args
.
num_workers
<
1
:
if
args
.
num_workers
<
1
:
raise
ValueError
(
f
"--num-workers must be at least 1, got
{
args
.
num_workers
}
"
)
raise
ValueError
(
f
"--num-workers must be at least 1, got
{
args
.
num_workers
}
"
)
# Parse and validate bootstrap_ports
args
.
bootstrap_ports_list
=
parse_bootstrap_ports
(
args
.
bootstrap_ports
)
if
args
.
bootstrap_ports_list
:
if
len
(
args
.
bootstrap_ports_list
)
!=
args
.
num_workers
:
raise
ValueError
(
f
"--bootstrap-ports must have exactly --num-workers (
{
args
.
num_workers
}
) ports, "
f
"got
{
len
(
args
.
bootstrap_ports_list
)
}
:
{
args
.
bootstrap_ports_list
}
"
)
# Set endpoint default based on worker type if not explicitly provided
# Set endpoint default based on worker type if not explicitly provided
if
args
.
endpoint
is
None
:
if
args
.
endpoint
is
None
:
if
args
.
is_prefill_worker
:
if
args
.
is_prefill_worker
:
...
...
components/src/dynamo/mocker/main.py
View file @
0980b27f
...
@@ -5,9 +5,12 @@
...
@@ -5,9 +5,12 @@
# Now supports vLLM-style individual arguments for MockEngineArgs
# Now supports vLLM-style individual arguments for MockEngineArgs
import
asyncio
import
asyncio
import
json
import
logging
import
logging
import
os
import
os
import
signal
import
signal
import
tempfile
from
pathlib
import
Path
import
uvloop
import
uvloop
...
@@ -85,6 +88,13 @@ async def launch_workers(args, extra_engine_args_path):
...
@@ -85,6 +88,13 @@ async def launch_workers(args, extra_engine_args_path):
loop
=
asyncio
.
get_running_loop
()
loop
=
asyncio
.
get_running_loop
()
futures
=
[]
futures
=
[]
runtimes
=
[]
runtimes
=
[]
per_worker_temp_files
:
list
[
Path
]
=
[]
# Load base engine args if we need to create per-worker files with bootstrap_port
base_engine_args
=
None
if
args
.
bootstrap_ports_list
:
with
open
(
extra_engine_args_path
)
as
f
:
base_engine_args
=
json
.
load
(
f
)
for
worker_id
in
range
(
args
.
num_workers
):
for
worker_id
in
range
(
args
.
num_workers
):
logger
.
info
(
f
"Creating mocker worker
{
worker_id
+
1
}
/
{
args
.
num_workers
}
"
)
logger
.
info
(
f
"Creating mocker worker
{
worker_id
+
1
}
/
{
args
.
num_workers
}
"
)
...
@@ -93,13 +103,30 @@ async def launch_workers(args, extra_engine_args_path):
...
@@ -93,13 +103,30 @@ async def launch_workers(args, extra_engine_args_path):
runtime
=
DistributedRuntime
(
loop
,
args
.
store_kv
,
args
.
request_plane
)
runtime
=
DistributedRuntime
(
loop
,
args
.
store_kv
,
args
.
request_plane
)
runtimes
.
append
(
runtime
)
runtimes
.
append
(
runtime
)
# Determine which engine args file to use
if
args
.
bootstrap_ports_list
:
# Create per-worker temp file with this worker's bootstrap_port
worker_args
=
base_engine_args
.
copy
()
worker_args
[
"bootstrap_port"
]
=
args
.
bootstrap_ports_list
[
worker_id
]
with
tempfile
.
NamedTemporaryFile
(
mode
=
"w"
,
suffix
=
".json"
,
delete
=
False
)
as
f
:
json
.
dump
(
worker_args
,
f
)
worker_engine_args_path
=
Path
(
f
.
name
)
per_worker_temp_files
.
append
(
worker_engine_args_path
)
logger
.
debug
(
f
"Worker
{
worker_id
}
: using bootstrap_port
{
args
.
bootstrap_ports_list
[
worker_id
]
}
"
)
else
:
worker_engine_args_path
=
extra_engine_args_path
# Create EntrypointArgs for this worker
# Create EntrypointArgs for this worker
entrypoint_args
=
EntrypointArgs
(
entrypoint_args
=
EntrypointArgs
(
engine_type
=
EngineType
.
Mocker
,
engine_type
=
EngineType
.
Mocker
,
model_path
=
args
.
model_path
,
model_path
=
args
.
model_path
,
model_name
=
args
.
model_name
,
model_name
=
args
.
model_name
,
endpoint_id
=
args
.
endpoint
,
endpoint_id
=
args
.
endpoint
,
extra_engine_args
=
extra
_engine_args_path
,
extra_engine_args
=
worker
_engine_args_path
,
is_prefill
=
args
.
is_prefill_worker
,
is_prefill
=
args
.
is_prefill_worker
,
)
)
...
@@ -130,6 +157,13 @@ async def launch_workers(args, extra_engine_args_path):
...
@@ -130,6 +157,13 @@ async def launch_workers(args, extra_engine_args_path):
for
runtime
in
runtimes
:
for
runtime
in
runtimes
:
runtime
.
shutdown
()
runtime
.
shutdown
()
# Clean up per-worker temp files
for
temp_file
in
per_worker_temp_files
:
try
:
temp_file
.
unlink
()
except
Exception
:
pass
def
main
():
def
main
():
uvloop
.
run
(
worker
())
uvloop
.
run
(
worker
())
...
...
components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py
View file @
0980b27f
...
@@ -96,21 +96,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
...
@@ -96,21 +96,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
bootstrap_room
=
self
.
_generate_bootstrap_room
()
bootstrap_room
=
self
.
_generate_bootstrap_room
()
logging
.
debug
(
f
"Generated bootstrap_room locally:
{
bootstrap_room
}
"
)
logging
.
debug
(
f
"Generated bootstrap_room locally:
{
bootstrap_room
}
"
)
bootstrap_info
=
{
"bootstrap_host"
:
self
.
bootstrap_host
,
"bootstrap_port"
:
self
.
bootstrap_port
,
"bootstrap_room"
:
bootstrap_room
,
}
# Yield in LLMEngineOutput format for PrefillRouter compatibility
# The disaggregated_params field contains the bootstrap info
yield
{
"token_ids"
:
[],
"text"
:
None
,
"finish_reason"
:
None
,
"disaggregated_params"
:
bootstrap_info
,
}
input_param
=
self
.
_get_input_param
(
inner_request
)
input_param
=
self
.
_get_input_param
(
inner_request
)
# Propagate trace context to SGLang
# Propagate trace context to SGLang
...
...
lib/bindings/c/src/lib.rs
View file @
0980b27f
...
@@ -1080,6 +1080,9 @@ pub fn add_query_instance_id(
...
@@ -1080,6 +1080,9 @@ pub fn add_query_instance_id(
///
///
/// For disaggregated mode: sets `prefill_worker_id` and `decode_worker_id`
/// For disaggregated mode: sets `prefill_worker_id` and `decode_worker_id`
/// For aggregated mode: sets `backend_instance_id` (when both IDs are the same)
/// For aggregated mode: sets `backend_instance_id` (when both IDs are the same)
///
/// Also sets `enable_local_updates: false` since the external caller (EPP/GAIE)
/// will handle bookkeeping via C FFI functions.
pub
fn
set_worker_ids_for_stage2
(
pub
fn
set_worker_ids_for_stage2
(
request
:
&
mut
NvCreateChatCompletionRequest
,
request
:
&
mut
NvCreateChatCompletionRequest
,
decode_worker_id
:
Option
<
i64
>
,
decode_worker_id
:
Option
<
i64
>
,
...
@@ -1091,6 +1094,9 @@ pub fn set_worker_ids_for_stage2(
...
@@ -1091,6 +1094,9 @@ pub fn set_worker_ids_for_stage2(
.expect
(
"NvExt builder should not fail"
)
.expect
(
"NvExt builder should not fail"
)
});
});
// Disable local updates - external caller handles bookkeeping via C FFI
nvext
.enable_local_updates
=
Some
(
false
);
// Check if this is aggregated mode (same worker for both)
// Check if this is aggregated mode (same worker for both)
let
is_aggregated
=
prefill_worker_id
==
decode_worker_id
;
let
is_aggregated
=
prefill_worker_id
==
decode_worker_id
;
...
...
lib/llm/src/kv_router.rs
View file @
0980b27f
...
@@ -354,71 +354,76 @@ impl KvRouter {
...
@@ -354,71 +354,76 @@ impl KvRouter {
tracing
::
info!
(
"Worker query client initialized"
);
tracing
::
info!
(
"Worker query client initialized"
);
// Start KV event subscriber background process (only when use_kv_events is enabled)
// Start KV event subscriber background process (only when use_kv_events is enabled)
//
This is spawned as a background task to avoid blocking router startup.
//
We block here until at least one worker runtime config is registered,
//
T
he
task waits for runtime_configs to determine whether to use NATS Core or JetStream
.
//
t
he
n spawn the subscriber. This ensures the router is ready before accepting requests
.
if
kv_router_config
.use_kv_events
if
kv_router_config
.use_kv_events
&&
let
Indexer
::
KvIndexer
(
ref
kv_indexer
)
=
indexer
&&
let
Indexer
::
KvIndexer
(
ref
kv_indexer
)
=
indexer
{
{
// Clone everything needed for the background task
let
component_clone
=
component
.clone
();
let
kv_indexer_clone
=
kv_indexer
.clone
();
let
cancellation_token_clone
=
cancellation_token
.clone
();
let
mut
runtime_configs_rx_clone
=
runtime_configs_rx
.clone
();
let
mut
runtime_configs_rx_clone
=
runtime_configs_rx
.clone
();
let
worker_query_client_clone
=
worker_query
::
WorkerQueryClient
::
new
(
component
.clone
(),
runtime_configs_rx
.clone
());
tokio
::
spawn
(
async
move
{
// Wait for at least one worker runtime config to be registered
// Wait for runtime_configs to have at least one entry
tracing
::
info!
(
"Waiting for at least one worker runtime config to be registered..."
);
let
(
all_local_indexer
,
count
)
=
loop
{
let
(
all_local_indexer
,
count
)
=
loop
{
{
{
let
configs
=
runtime_configs_rx_clone
.borrow
();
let
configs
=
runtime_configs_rx_clone
.borrow
();
if
!
configs
.is_empty
()
{
if
!
configs
.is_empty
()
{
let
all_local_indexer
=
let
all_local_indexer
=
configs
.values
()
.all
(|
c
|
c
.enable_local_indexer
);
configs
.values
()
.all
(|
c
|
c
.enable_local_indexer
);
break
(
all_local_indexer
,
configs
.len
());
break
(
all_local_indexer
,
configs
.len
());
}
}
}
}
// Wait for changes to runtime_configs
// Wait for changes to runtime_configs
tokio
::
select!
{
tokio
::
select!
{
_
=
cancellation_token_clone
.cancelled
()
=>
{
_
=
cancellation_token
.cancelled
()
=>
{
tracing
::
debug!
(
"Subscriber selection task cancelled"
);
tracing
::
debug!
(
"KvRouter startup cancelled while waiting for workers"
);
return
;
anyhow
::
bail!
(
"KvRouter startup cancelled"
);
}
}
result
=
runtime_configs_rx_clone
.changed
()
=>
{
result
=
runtime_configs_rx_clone
.changed
()
=>
{
if
result
.is_err
()
{
if
result
.is_err
()
{
tracing
::
debug!
(
"Runtime configs channel closed"
);
tracing
::
debug!
(
"Runtime configs channel closed"
);
return
;
anyhow
::
bail!
(
"Runtime configs channel closed before any workers registered"
);
}
}
}
}
}
};
}
};
tracing
::
info!
(
"Found {count} worker runtime config(s), starting KV event subscriber"
);
if
all_local_indexer
{
// Clone everything needed for the background subscriber task
// All workers have local_indexer enabled - use NATS Core
let
component_clone
=
component
.clone
();
tracing
::
info!
(
let
kv_indexer_clone
=
kv_indexer
.clone
();
"All {count} workers have local_indexer enabled, using NATS Core subscription"
let
cancellation_token_clone
=
cancellation_token
.clone
();
);
let
worker_query_client_clone
=
worker_query
::
WorkerQueryClient
::
new
(
component
.clone
(),
runtime_configs_rx
.clone
());
// Spawn subscriber as background task (long-running)
if
all_local_indexer
{
// All workers have local_indexer enabled - use NATS Core
tracing
::
info!
(
"All {count} workers have local_indexer enabled, using NATS Core subscription"
);
tokio
::
spawn
(
async
move
{
if
let
Err
(
e
)
=
start_kv_router_background_nats_core
(
if
let
Err
(
e
)
=
start_kv_router_background_nats_core
(
component_clone
.clone
()
,
component_clone
,
kv_indexer_clone
.event_sender
(),
kv_indexer_clone
.event_sender
(),
kv_indexer_clone
.remove_worker_sender
(),
kv_indexer_clone
.remove_worker_sender
(),
cancellation_token_clone
.clone
()
,
cancellation_token_clone
,
worker_query_client_clone
,
worker_query_client_clone
,
)
)
.await
.await
{
{
tracing
::
error!
(
"Failed to start NATS Core subscriber: {e}"
);
tracing
::
error!
(
"Failed to start NATS Core subscriber: {e}"
);
}
}
}
else
{
});
// Not all workers have local_indexer - use JetStream
}
else
{
tracing
::
info!
(
// Not all workers have local_indexer - use JetStream
"Not all workers have local_indexer enabled, using JetStream subscription"
tracing
::
info!
(
);
"Not all workers have local_indexer enabled, using JetStream subscription"
);
tokio
::
spawn
(
async
move
{
if
let
Err
(
e
)
=
start_kv_router_background
(
if
let
Err
(
e
)
=
start_kv_router_background
(
component_clone
.clone
()
,
component_clone
,
consumer_id
,
consumer_id
,
kv_indexer_clone
.event_sender
(),
kv_indexer_clone
.event_sender
(),
kv_indexer_clone
.remove_worker_sender
(),
kv_indexer_clone
.remove_worker_sender
(),
...
@@ -428,7 +433,7 @@ impl KvRouter {
...
@@ -428,7 +433,7 @@ impl KvRouter {
kv_router_config
kv_router_config
.router_snapshot_threshold
.router_snapshot_threshold
.map
(|
_
|
kv_indexer_clone
.snapshot_event_sender
()),
.map
(|
_
|
kv_indexer_clone
.snapshot_event_sender
()),
cancellation_token_clone
.clone
()
,
cancellation_token_clone
,
kv_router_config
.router_snapshot_threshold
,
kv_router_config
.router_snapshot_threshold
,
kv_router_config
.router_reset_states
,
kv_router_config
.router_reset_states
,
)
)
...
@@ -436,8 +441,8 @@ impl KvRouter {
...
@@ -436,8 +441,8 @@ impl KvRouter {
{
{
tracing
::
error!
(
"Failed to start JetStream subscriber: {e}"
);
tracing
::
error!
(
"Failed to start JetStream subscriber: {e}"
);
}
}
}
}
);
}
);
}
}
}
tracing
::
info!
(
"KV Routing initialized"
);
tracing
::
info!
(
"KV Routing initialized"
);
...
@@ -815,17 +820,12 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
...
@@ -815,17 +820,12 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let
is_query_only
=
request
.get_annotation_value
(
"query_instance_id"
)
.is_some
();
let
is_query_only
=
request
.get_annotation_value
(
"query_instance_id"
)
.is_some
();
// Determine if this router should handle local state updates (add_request, free, etc.)
// Determine if this router should handle local state updates (add_request, free, etc.)
// Only skip local updates for GAIE Stage 2: when BOTH prefill and decode worker IDs
// Default is true (router handles bookkeeping). Set to false for GAIE Stage 2 where
// are externally specified (indicates external orchestrator handles tracking).
// an external orchestrator (e.g., EPP sidecar) handles bookkeeping via C FFI.
// For internal routing (e.g., bootstrap optimization with only prefill_worker_id set),
let
handle_local_updates
=
request
// we still handle updates locally.
.routing
let
routing
=
request
.routing
.as_ref
();
.as_ref
()
let
handle_local_updates
=
routing
.and_then
(|
r
|
r
.enable_local_updates
)
.map
(|
r
|
{
// GAIE Stage 2 sets both worker IDs - external caller handles tracking
// All other cases (including backend_instance_id for routing) - we handle locally
r
.prefill_worker_id
.is_none
()
||
r
.decode_worker_id
.is_none
()
})
.unwrap_or
(
true
);
.unwrap_or
(
true
);
// Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
// Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
...
@@ -917,9 +917,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
...
@@ -917,9 +917,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let
stream_context
=
response_stream
.context
();
let
stream_context
=
response_stream
.context
();
let
context_for_monitoring
=
stream_context
.clone
();
let
context_for_monitoring
=
stream_context
.clone
();
//
TODO: When handle_local_updates=false, consider moving
mark_prefill_completed
//
Wrap stream with lifecycle management (
mark_prefill_completed
, free)
//
to an external caller (e.g., sidecar) if they support a first-token hook
.
//
Only perform these operations if handle_local_updates is true
.
//
Currently mark_prefill_completed is called here for all flows
.
//
When false, an external caller (e.g., GAIE sidecar) handles bookkeeping via C FFI
.
let
wrapped_stream
=
Box
::
pin
(
async_stream
::
stream!
{
let
wrapped_stream
=
Box
::
pin
(
async_stream
::
stream!
{
let
mut
prefill_marked
=
false
;
let
mut
prefill_marked
=
false
;
...
@@ -937,7 +937,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
...
@@ -937,7 +937,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
break
;
break
;
};
};
if
!
prefill_marked
{
if
handle_local_updates
&&
!
prefill_marked
{
// Only mark prefill completed when we receive actual tokens,
// Only mark prefill completed when we receive actual tokens,
// not empty bootstrap info (token_ids: []) from disaggregated prefill
// not empty bootstrap info (token_ids: []) from disaggregated prefill
let
has_tokens
=
item
.data
.as_ref
()
let
has_tokens
=
item
.data
.as_ref
()
...
@@ -956,8 +956,11 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
...
@@ -956,8 +956,11 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
}
}
}
// Always call free() - it's idempotent and safe even if already freed or never added
// Only call free() if we handle local updates.
if
let
Err
(
e
)
=
chooser
.free
(
&
context_id
)
.await
{
// When handle_local_updates=false, external caller handles cleanup via C FFI.
if
handle_local_updates
&&
let
Err
(
e
)
=
chooser
.free
(
&
context_id
)
.await
{
tracing
::
warn!
(
"Failed to free request {context_id}: {e}"
);
tracing
::
warn!
(
"Failed to free request {context_id}: {e}"
);
}
}
});
});
...
...
lib/llm/src/kv_router/prefill_router.rs
View file @
0980b27f
...
@@ -6,7 +6,7 @@ use std::sync::{Arc, OnceLock};
...
@@ -6,7 +6,7 @@ use std::sync::{Arc, OnceLock};
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
futures
::
StreamExt
;
use
futures
::
StreamExt
;
use
rand
::
Rng
;
use
rand
::
Rng
;
use
tokio
::
sync
::
oneshot
;
use
tokio
::
sync
::
{
OwnedSemaphorePermit
,
oneshot
}
;
use
tokio_util
::
sync
::
CancellationToken
;
use
tokio_util
::
sync
::
CancellationToken
;
use
dynamo_runtime
::{
use
dynamo_runtime
::{
...
@@ -24,7 +24,6 @@ use crate::{
...
@@ -24,7 +24,6 @@ use crate::{
protocols
::
common
::
llm_backend
::{
LLMEngineOutput
,
PreprocessedRequest
},
protocols
::
common
::
llm_backend
::{
LLMEngineOutput
,
PreprocessedRequest
},
protocols
::
common
::
preprocessor
::{
BootstrapInfo
,
PrefillResult
},
protocols
::
common
::
preprocessor
::{
BootstrapInfo
,
PrefillResult
},
protocols
::
common
::
timing
::{
RequestPhase
,
RequestTracker
},
protocols
::
common
::
timing
::{
RequestPhase
,
RequestTracker
},
protocols
::
openai
::
nvext
::
WorkerIdInfo
,
};
};
/// Errors that can occur during prefill routing
/// Errors that can occur during prefill routing
...
@@ -85,10 +84,10 @@ impl InnerPrefillRouter {
...
@@ -85,10 +84,10 @@ impl InnerPrefillRouter {
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
/// from the prefill response and injecting them into the decode request.
///
///
///
Supports regular Dynamo and GAIE integrated mode via query_instance_id state machine
:
///
Modes
:
/// -
GAIE Stage 1
: query_instance_id
transitions "" -> "prefill" -> "decode",
returns
only
worker IDs
/// -
Query-only
:
`
query_instance_id
` annotation present →
returns worker IDs
without execution
/// -
GAIE Stage 2: routing.
prefill_worker_id
/routing.
decode_worker_id
are set, full execution with
specified workers
/// -
Pre-routed: `
prefill_worker_id
`/`
decode_worker_id
` set → routes to
specified workers
/// - No
n-GAIE: like GAIE Stage 2 but the worker ids have to be determined.
/// - No
rmal: Worker IDs determined by router based on KV cache state
pub
struct
PrefillRouter
{
pub
struct
PrefillRouter
{
prefill_router
:
OnceLock
<
InnerPrefillRouter
>
,
prefill_router
:
OnceLock
<
InnerPrefillRouter
>
,
model_manager
:
Arc
<
ModelManager
>
,
model_manager
:
Arc
<
ModelManager
>
,
...
@@ -232,11 +231,6 @@ impl PrefillRouter {
...
@@ -232,11 +231,6 @@ impl PrefillRouter {
Ok
(())
Ok
(())
}
}
/// Generate a unique bootstrap room ID for disaggregated serving
fn
generate_bootstrap_room
()
->
u64
{
rand
::
rng
()
.random
()
}
/// Build bootstrap_info for disaggregated serving
/// Build bootstrap_info for disaggregated serving
/// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes).
/// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes).
...
@@ -250,7 +244,6 @@ impl PrefillRouter {
...
@@ -250,7 +244,6 @@ impl PrefillRouter {
// Worker selection
// Worker selection
let
(
worker_id
,
dp_rank
)
=
if
let
Some
(
id
)
=
preselected_worker
{
let
(
worker_id
,
dp_rank
)
=
if
let
Some
(
id
)
=
preselected_worker
{
// GAIE Stage 2: use pre-selected worker
let
dp_rank
=
req
.routing
.as_ref
()
.and_then
(|
r
|
r
.dp_rank
)
.unwrap_or
(
0
);
let
dp_rank
=
req
.routing
.as_ref
()
.and_then
(|
r
|
r
.dp_rank
)
.unwrap_or
(
0
);
tracing
::
debug!
(
tracing
::
debug!
(
worker_id
=
id
,
worker_id
=
id
,
...
@@ -285,7 +278,7 @@ impl PrefillRouter {
...
@@ -285,7 +278,7 @@ impl PrefillRouter {
let
host
=
endpoint
.bootstrap_host
?
;
let
host
=
endpoint
.bootstrap_host
?
;
let
port
=
endpoint
.bootstrap_port
?
;
let
port
=
endpoint
.bootstrap_port
?
;
let
bootstrap_room
=
Self
::
generate_bootstrap_ro
om
();
let
bootstrap_room
:
u64
=
rand
::
rng
()
.rand
om
();
tracing
::
info!
(
tracing
::
info!
(
worker_id
=
worker_id
,
worker_id
=
worker_id
,
...
@@ -308,12 +301,18 @@ impl PrefillRouter {
...
@@ -308,12 +301,18 @@ impl PrefillRouter {
))
))
}
}
/// Execute prefill with the given router and extract structured result
/// Execute prefill with the given router and extract structured result.
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization)
///
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization).
///
/// If `phase_permit` is provided, it is dropped as soon as the request has been routed
/// (before the first prefill output is awaited), allowing subsequent `set_phase` calls to
/// proceed. This is used in the bootstrap optimization path to ensure `record_worker`
/// completes before the phase changes.
async
fn
execute_prefill
(
async
fn
execute_prefill
(
router
:
Option
<
InnerPrefillRouter
>
,
router
:
Option
<
InnerPrefillRouter
>
,
request
:
SingleIn
<
PreprocessedRequest
>
,
request
:
SingleIn
<
PreprocessedRequest
>
,
target_worker
:
Option
<
u64
>
,
target_worker
:
Option
<
u64
>
,
phase_permit
:
Option
<
OwnedSemaphorePermit
>
,
)
->
Result
<
(
PrefillResult
,
Option
<
u64
>
),
PrefillError
>
{
)
->
Result
<
(
PrefillResult
,
Option
<
u64
>
),
PrefillError
>
{
let
router
=
router
.ok_or
(
PrefillError
::
NotActivated
)
?
;
let
router
=
router
.ok_or
(
PrefillError
::
NotActivated
)
?
;
let
mut
prefill_response
=
router
let
mut
prefill_response
=
router
...
@@ -321,6 +320,10 @@ impl PrefillRouter {
...
@@ -321,6 +320,10 @@ impl PrefillRouter {
.await
.await
.map_err
(|
e
|
PrefillError
::
PrefillError
(
e
.to_string
()))
?
;
.map_err
(|
e
|
PrefillError
::
PrefillError
(
e
.to_string
()))
?
;
// Drop phase permit now - routing is complete, record_worker was called in select_worker.
// This unblocks set_phase(Decode) in the main task without waiting for prefill output.
drop
(
phase_permit
);
let
Some
(
first_output
)
=
prefill_response
.next
()
.await
else
{
let
Some
(
first_output
)
=
prefill_response
.next
()
.await
else
{
return
Err
(
PrefillError
::
PrefillError
(
return
Err
(
PrefillError
::
PrefillError
(
"Prefill router returned no output (stream ended)"
.to_string
(),
"Prefill router returned no output (stream ended)"
.to_string
(),
...
@@ -379,17 +382,24 @@ impl PrefillRouter {
...
@@ -379,17 +382,24 @@ impl PrefillRouter {
))
))
}
}
/// Spawn prefill as a background task
/// Spawn prefill as a background task.
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization)
///
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization).
///
/// The `phase_permit` is passed to the spawned task and dropped once routing completes
/// (before the first prefill output is awaited), allowing the main task's
/// `set_phase(Decode)` to proceed.
fn
spawn_prefill_task
(
fn
spawn_prefill_task
(
&
self
,
&
self
,
prefill_request
:
SingleIn
<
PreprocessedRequest
>
,
prefill_request
:
SingleIn
<
PreprocessedRequest
>
,
target_worker
:
Option
<
u64
>
,
target_worker
:
Option
<
u64
>
,
phase_permit
:
OwnedSemaphorePermit
,
)
{
)
{
let
router
=
self
.prefill_router
.get
()
.cloned
();
let
router
=
self
.prefill_router
.get
()
.cloned
();
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
match
Self
::
execute_prefill
(
router
,
prefill_request
,
target_worker
)
.await
{
match
Self
::
execute_prefill
(
router
,
prefill_request
,
target_worker
,
Some
(
phase_permit
))
.await
{
Ok
(
_
)
=>
{
Ok
(
_
)
=>
{
tracing
::
debug!
(
"Prefill background task completed"
);
tracing
::
debug!
(
"Prefill background task completed"
);
}
}
...
@@ -400,67 +410,17 @@ impl PrefillRouter {
...
@@ -400,67 +410,17 @@ impl PrefillRouter {
});
});
}
}
/// Call the prefill router and extract structured prefill result and worker ID
/// Call the prefill router and extract structured prefill result and worker ID.
///
/// This is the synchronous prefill path - we wait for prefill to complete before proceeding.
/// No phase permit is needed since `record_worker` completes before we return.
async
fn
call_prefill
(
async
fn
call_prefill
(
&
self
,
&
self
,
request
:
SingleIn
<
PreprocessedRequest
>
,
request
:
SingleIn
<
PreprocessedRequest
>
,
)
->
Result
<
(
PrefillResult
,
Option
<
u64
>
),
PrefillError
>
{
)
->
Result
<
(
PrefillResult
,
Option
<
u64
>
),
PrefillError
>
{
// For call_prefill path, routing is handled by the router itself (no direct routing needed)
// For call_prefill path, routing is handled by the router itself (no direct routing needed)
Self
::
execute_prefill
(
self
.prefill_router
.get
()
.cloned
(),
request
,
None
)
.await
// No phase permit needed - we wait for completion before changing phase
}
Self
::
execute_prefill
(
self
.prefill_router
.get
()
.cloned
(),
request
,
None
,
None
)
.await
}
/// Helpers that shape prefill/decode requests for the two-stage GAIE flow.
impl PrefillRouter {
    /// Adjust a prefill request for GAIE.
    ///
    /// - Stage 1 (`is_gaie_stage1 == true`): replaces any `query_instance_id*`
    ///   annotation with `query_instance_id:<RequestPhase::Prefill>`.
    /// - Otherwise: if the request's routing carries a `prefill_worker_id`,
    ///   pins `backend_instance_id` to that worker. Requests without one are
    ///   left untouched.
    fn prepare_prefill_for_gaie(prefill_req: &mut PreprocessedRequest, is_gaie_stage1: bool) {
        if is_gaie_stage1 {
            // Drop stale markers first so exactly one query_instance_id annotation remains.
            prefill_req
                .annotations
                .retain(|a| !a.starts_with("query_instance_id"));
            let marker = format!("query_instance_id:{}", RequestPhase::Prefill);
            prefill_req.annotations.push(marker);
            return;
        }

        let preselected = prefill_req
            .routing
            .as_ref()
            .and_then(|r| r.prefill_worker_id);
        let Some(prefill_worker_id) = preselected else {
            return;
        };

        // Stage 2: honour the worker chosen during stage 1.
        tracing::debug!(
            prefill_worker_id = prefill_worker_id,
            "GAIE Stage 2: Routing prefill to pre-selected worker"
        );
        prefill_req.routing_mut().backend_instance_id = Some(prefill_worker_id);
    }

    /// Adjust a decode request for GAIE Stage 1.
    ///
    /// Reads `disaggregated_params["worker_id"]` from the prefill result and
    /// deserializes it as `WorkerIdInfo`. When that yields a prefill worker id,
    /// the decode request's `query_instance_id*` annotations are replaced with
    /// `query_instance_id:<RequestPhase::Decode>` and a
    /// `prefill_worker_id:<id>` annotation is appended. If the entry is
    /// missing, fails to deserialize, or has no id, the request is unchanged.
    fn prepare_decode_for_gaie_stage1(
        decode_req: &mut PreprocessedRequest,
        prefill_result: &PrefillResult,
    ) {
        let worker_info = prefill_result
            .disaggregated_params
            .get("worker_id")
            .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());
        let Some(worker_id) = worker_info.and_then(|info| info.prefill_worker_id) else {
            return;
        };

        decode_req
            .annotations
            .retain(|a| !a.starts_with("query_instance_id"));
        let phase_marker = format!("query_instance_id:{}", RequestPhase::Decode);
        decode_req.annotations.push(phase_marker);
        let worker_marker = format!("prefill_worker_id:{worker_id}");
        decode_req.annotations.push(worker_marker);
    }
}
}
}
...
@@ -490,22 +450,14 @@ impl
...
@@ -490,22 +450,14 @@ impl
let
request_id
=
context
.id
()
.to_string
();
let
request_id
=
context
.id
()
.to_string
();
let
engine_ctx
=
context
.context
();
let
engine_ctx
=
context
.context
();
// GAIE Stage 1: the presence of the empty query_instance_id signals query-only mode
// State machine: "" -> "prefill" -> "decode" (disagg) OR "" -> aggregated worker (agg fallback)
let
is_gaie_stage1
=
req
.get_annotation_value
(
"query_instance_id"
)
.is_some_and
(|
s
|
s
.is_empty
());
// Save original max_tokens for decode
// Save original max_tokens for decode
let
original_max_tokens
=
req
.stop_conditions.max_tokens
;
let
original_max_tokens
=
req
.stop_conditions.max_tokens
;
// GAIE Stage 1: Check if prefill router is activated - if not, skip to decode
// If prefill router is not activated, skip directly to decode
if
is_gaie_stage1
&&
self
.prefill_router
.get
()
.is_none
()
{
if
self
.prefill_router
.get
()
.is_none
()
{
tracing
::
debug!
(
"GAIE Stage 1: Prefill router not activated, skipping to decode"
);
if
self
.enforce_disagg
{
if
self
.enforce_disagg
{
return
Err
(
anyhow
::
anyhow!
(
PrefillError
::
NotActivated
));
return
Err
(
anyhow
::
anyhow!
(
PrefillError
::
NotActivated
));
}
}
// Fall back to decode-only
return
next
.generate
(
context
.map
(|
_
|
req
))
.await
;
return
next
.generate
(
context
.map
(|
_
|
req
))
.await
;
}
}
...
@@ -515,47 +467,45 @@ impl
...
@@ -515,47 +467,45 @@ impl
req
.tracker
=
Some
(
Arc
::
new
(
RequestTracker
::
new
()));
req
.tracker
=
Some
(
Arc
::
new
(
RequestTracker
::
new
()));
}
}
let
tracker
=
req
.tracker
.as_ref
()
.unwrap
();
let
tracker
=
req
.tracker
.as_ref
()
.unwrap
();
tracker
.set_phase
(
RequestPhase
::
Prefill
);
let
prefill_phase_permit
=
tracker
.set_phase
(
RequestPhase
::
Prefill
)
.await
;
tracker
.record_prefill_start
();
tracker
.record_prefill_start
();
// Prepare prefill request with max_tokens = 1 (clone after tracker is set)
// Prepare prefill request with max_tokens = 1 (clone after tracker is set)
let
mut
prefill_req
=
req
.clone
();
let
mut
prefill_req
=
req
.clone
();
prefill_req
.stop_conditions.max_tokens
=
Some
(
1
);
prefill_req
.stop_conditions.max_tokens
=
Some
(
1
);
// Prepare prefill request for GAIE flows (Stage 1 or Stage 2)
// Try build_bootstrap_info optimization: if we can get bootstrap info upfront,
Self
::
prepare_prefill_for_gaie
(
&
mut
prefill_req
,
is_gaie_stage1
);
// spawn prefill in background and proceed to decode immediately.
// Try build_bootstrap_info optimization (skip for GAIE Stage 1 which needs query-only flow)
// For GAIE Stage 2, use prefill_worker_id if provided
let
preselected_worker
=
prefill_req
let
preselected_worker
=
prefill_req
.routing
.routing
.as_ref
()
.as_ref
()
.and_then
(|
r
|
r
.prefill_worker_id
);
.and_then
(|
r
|
r
.prefill_worker_id
);
let
prefill_result
=
if
!
is_gaie_stage1
let
prefill_result
=
if
let
Some
((
worker_id
,
dp_rank
,
bootstrap_info
))
=
self
&&
let
Some
((
worker_id
,
dp_rank
,
bootstrap_info
))
=
self
.build_bootstrap_info
(
&
prefill_req
,
preselected_worker
)
.build_bootstrap_info
(
&
prefill_req
,
preselected_worker
)
.await
.await
{
{
// Bootstrap optimization path: spawn prefill in background
// Bootstrap optimization path: spawn prefill in background
let
routing
=
prefill_req
.routing_mut
();
let
routing
=
prefill_req
.routing_mut
();
routing
.prefill_worker_id
=
Some
(
worker_id
);
routing
.prefill_worker_id
=
Some
(
worker_id
);
routing
.backend_instance_id
=
Some
(
worker_id
);
// Route prefill to the SAME worker we got bootstrap_info from
routing
.dp_rank
=
Some
(
dp_rank
);
routing
.dp_rank
=
Some
(
dp_rank
);
prefill_req
.bootstrap_info
=
Some
(
bootstrap_info
.clone
());
prefill_req
.bootstrap_info
=
Some
(
bootstrap_info
.clone
());
let
prefill_context
=
Context
::
with_id
(
prefill_req
,
request_id
.clone
());
let
prefill_context
=
Context
::
with_id
(
prefill_req
,
request_id
.clone
());
engine_ctx
.link_child
(
prefill_context
.context
());
engine_ctx
.link_child
(
prefill_context
.context
());
self
.spawn_prefill_task
(
prefill_context
,
Some
(
worker_id
));
// Pass phase permit to spawned task - it drops after first output (record_worker complete)
// This allows set_phase(Decode) below to proceed only after prefill routing is done
self
.spawn_prefill_task
(
prefill_context
,
Some
(
worker_id
),
prefill_phase_permit
);
Ok
((
None
,
Some
(
worker_id
),
Some
(
bootstrap_info
)))
Ok
((
None
,
Some
(
worker_id
),
Some
(
bootstrap_info
)))
}
else
{
}
else
{
// Original prefill path: wait for prefill to complete
// Original prefill path: wait for prefill to complete
tracing
::
debug!
(
tracing
::
debug!
(
"Using original prefill path"
);
is_gaie_stage1
=
is_gaie_stage1
,
"Using original prefill path"
// Drop the phase permit before calling call_prefill - we wait for completion
);
// so there's no race with set_phase(Decode) below
drop
(
prefill_phase_permit
);
let
prefill_context
=
Context
::
with_id
(
prefill_req
,
request_id
.clone
());
let
prefill_context
=
Context
::
with_id
(
prefill_req
,
request_id
.clone
());
engine_ctx
.link_child
(
prefill_context
.context
());
engine_ctx
.link_child
(
prefill_context
.context
());
...
@@ -579,20 +529,18 @@ impl
...
@@ -579,20 +529,18 @@ impl
Ok
((
maybe_prefill_result
,
_
prefill_worker_id
,
bootstrap_info
))
=>
{
Ok
((
maybe_prefill_result
,
_
prefill_worker_id
,
bootstrap_info
))
=>
{
tracing
::
debug!
(
"Prefill completed, proceeding to decode"
);
tracing
::
debug!
(
"Prefill completed, proceeding to decode"
);
// Set phase to Decode for the decode request
// Set phase to Decode for the decode request.
// In bootstrap path, this blocks until the spawned prefill task drops its permit
// (after first output / record_worker completes), ensuring correct phase for routing.
if
let
Some
(
ref
tracker
)
=
req
.tracker
{
if
let
Some
(
ref
tracker
)
=
req
.tracker
{
tracker
.set_phase
(
RequestPhase
::
Decode
);
let
_
decode_permit
=
tracker
.set_phase
(
RequestPhase
::
Decode
)
.await
;
// Permit is dropped immediately - decode proceeds, no need to hold it
}
}
let
mut
decode_req
=
req
;
let
mut
decode_req
=
req
;
// Update request with prefill result
// Update request with prefill result
if
is_gaie_stage1
{
if
let
Some
(
prefill_result
)
=
maybe_prefill_result
{
if
let
Some
(
ref
prefill_result
)
=
maybe_prefill_result
{
Self
::
prepare_decode_for_gaie_stage1
(
&
mut
decode_req
,
prefill_result
);
}
}
else
if
let
Some
(
prefill_result
)
=
maybe_prefill_result
{
// Normal or GAIE Stage 2: Set prefill_result for decode
decode_req
.prefill_result
=
Some
(
prefill_result
);
decode_req
.prefill_result
=
Some
(
prefill_result
);
}
}
...
@@ -611,17 +559,6 @@ impl
...
@@ -611,17 +559,6 @@ impl
..
existing_override
.unwrap_or_default
()
..
existing_override
.unwrap_or_default
()
});
});
// GAIE Stage 2: Route to pre-selected decode worker if specified
if
let
Some
(
decode_worker_id
)
=
decode_req
.routing
.as_ref
()
.and_then
(|
r
|
r
.decode_worker_id
)
{
decode_req
.routing_mut
()
.backend_instance_id
=
Some
(
decode_worker_id
);
tracing
::
debug!
(
decode_worker_id
=
decode_worker_id
,
"GAIE Stage 2: Routing decode to pre-selected worker"
);
}
// Map the modified request through with preserved context
// Map the modified request through with preserved context
let
decode_request
=
context
.map
(|
_
|
decode_req
);
let
decode_request
=
context
.map
(|
_
|
decode_req
);
next
.generate
(
decode_request
)
.await
next
.generate
(
decode_request
)
.await
...
...
lib/llm/src/kv_router/subscriber.rs
View file @
0980b27f
...
@@ -651,6 +651,51 @@ pub async fn start_kv_router_background(
...
@@ -651,6 +651,51 @@ pub async fn start_kv_router_background(
Ok
(())
Ok
(())
}
}
/// React to a single worker discovery event.
///
/// - `Removed(worker_id)`: forwards the id on `remove_worker_tx`; a failed send
///   is logged and otherwise ignored.
/// - `Added(instance)`: calls `recover_from_worker` for the instance's id with
///   no start/end bounds, feeding recovered events into `kv_events_tx`.
///   Failure is logged as a warning and otherwise ignored.
///
/// Never returns an error: every failure path is logged and swallowed.
async fn handle_worker_discovery(
    event: DiscoveryEvent,
    worker_query_client: &WorkerQueryClient,
    kv_events_tx: &mpsc::Sender<RouterEvent>,
    remove_worker_tx: &mpsc::Sender<WorkerId>,
) {
    // Handle removal up front so the recovery path below reads linearly.
    let instance = match event {
        DiscoveryEvent::Removed(worker_id) => {
            tracing::warn!("DISCOVERY: Worker {worker_id} removed, removing from router indexer");

            let sent = remove_worker_tx.send(worker_id).await;
            if let Err(e) = sent {
                tracing::warn!("Failed to send worker removal for worker {worker_id}: {e}");
            }
            return;
        }
        DiscoveryEvent::Added(instance) => instance,
    };

    let worker_id = instance.instance_id();
    tracing::info!("DISCOVERY: Worker {worker_id} added, dumping local indexer into router");

    let recovered = recover_from_worker(
        worker_query_client,
        worker_id,
        None, // Start from beginning
        None, // Get all events
        kv_events_tx,
    )
    .await;

    match recovered {
        Err(e) => {
            tracing::warn!(
                "Failed to dump worker {worker_id}'s local indexer (may not have local indexer enabled): {e}"
            );
        }
        Ok(count) => {
            tracing::info!(
                "Successfully dumped worker {worker_id}'s local indexer, recovered {count} events"
            );
        }
    }
}
/// Start a simplified background task for event consumption using NATS Core.
/// Start a simplified background task for event consumption using NATS Core.
///
///
/// This is used when local indexer mode is enabled. Unlike `start_kv_router_background`,
/// This is used when local indexer mode is enabled. Unlike `start_kv_router_background`,
...
@@ -660,6 +705,9 @@ pub async fn start_kv_router_background(
...
@@ -660,6 +705,9 @@ pub async fn start_kv_router_background(
/// - On worker Added: dumps worker's local indexer into router
/// - On worker Added: dumps worker's local indexer into router
/// - On worker Removed: removes worker from router indexer
/// - On worker Removed: removes worker from router indexer
///
///
/// This function first recovers state from all currently registered workers before
/// spawning the background task, ensuring the router is ready before returning.
///
/// This is appropriate when workers have local indexers enabled.
/// This is appropriate when workers have local indexers enabled.
pub
async
fn
start_kv_router_background_nats_core
(
pub
async
fn
start_kv_router_background_nats_core
(
component
:
Component
,
component
:
Component
,
...
@@ -688,6 +736,40 @@ pub async fn start_kv_router_background_nats_core(
...
@@ -688,6 +736,40 @@ pub async fn start_kv_router_background_nats_core(
.list_and_watch
(
generate_discovery_key
,
Some
(
cancellation_token
.clone
()))
.list_and_watch
(
generate_discovery_key
,
Some
(
cancellation_token
.clone
()))
.await
?
;
.await
?
;
// Drain and process all existing workers before spawning the background loop.
// list_and_watch returns existing instances first, so we poll with a short timeout
// to process all initial workers synchronously before the router becomes "ready".
loop
{
// Use a short timeout to detect when initial discovery events are exhausted
let
poll_result
=
tokio
::
time
::
timeout
(
Duration
::
from_millis
(
100
),
instance_event_stream
.next
())
.await
;
match
poll_result
{
Ok
(
Some
(
Ok
(
event
)))
=>
{
handle_worker_discovery
(
event
,
&
worker_query_client
,
&
kv_events_tx
,
&
remove_worker_tx
,
)
.await
;
}
Ok
(
Some
(
Err
(
e
)))
=>
{
tracing
::
warn!
(
"Error receiving discovery event during initial sync: {e}"
);
}
Ok
(
None
)
=>
{
// Stream ended
tracing
::
warn!
(
"Discovery stream ended during initial sync"
);
break
;
}
Err
(
_
)
=>
{
// Timeout - no more initial events
tracing
::
debug!
(
"Initial worker discovery sync complete"
);
break
;
}
}
}
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
// Track last received event ID per worker for gap detection
// Track last received event ID per worker for gap detection
let
mut
last_event_ids
:
HashMap
<
WorkerId
,
u64
>
=
HashMap
::
new
();
let
mut
last_event_ids
:
HashMap
<
WorkerId
,
u64
>
=
HashMap
::
new
();
...
@@ -703,51 +785,17 @@ pub async fn start_kv_router_background_nats_core(
...
@@ -703,51 +785,17 @@ pub async fn start_kv_router_background_nats_core(
// Handle generate endpoint instance add/remove events
// Handle generate endpoint instance add/remove events
Some
(
discovery_event_result
)
=
instance_event_stream
.next
()
=>
{
Some
(
discovery_event_result
)
=
instance_event_stream
.next
()
=>
{
let
Ok
(
discovery_
event
)
=
discovery_event_result
else
{
let
Ok
(
event
)
=
discovery_event_result
else
{
continue
;
continue
;
};
};
match
discovery_event
{
handle_worker_discovery
(
DiscoveryEvent
::
Added
(
_
instance
)
=>
{
event
,
// Extract worker_id from the instance
&
worker_query_client
,
let
worker_id
=
_
instance
.instance_id
();
&
kv_events_tx
,
&
remove_worker_tx
,
tracing
::
info!
(
)
"DISCOVERY: Worker {worker_id} added, dumping local indexer into router"
.await
;
);
// Query worker's local indexer and dump all events
match
recover_from_worker
(
&
worker_query_client
,
worker_id
,
None
,
// Start from beginning
None
,
// Get all events
&
kv_events_tx
,
)
.await
{
Ok
(
count
)
=>
{
tracing
::
info!
(
"Successfully dumped worker {worker_id}'s local indexer, recovered {count} events"
);
}
Err
(
e
)
=>
{
tracing
::
warn!
(
"Failed to dump worker {worker_id}'s local indexer (may not have local indexer enabled): {e}"
);
}
}
}
DiscoveryEvent
::
Removed
(
worker_id
)
=>
{
tracing
::
warn!
(
"DISCOVERY: Worker {worker_id} removed, removing from router indexer"
);
if
let
Err
(
e
)
=
remove_worker_tx
.send
(
worker_id
)
.await
{
tracing
::
warn!
(
"Failed to send worker removal for worker {worker_id}: {e}"
);
}
}
}
}
}
// Handle event consumption from NATS Core subscription
// Handle event consumption from NATS Core subscription
...
...
lib/llm/src/local_model.rs
View file @
0980b27f
...
@@ -10,9 +10,10 @@ use dynamo_runtime::discovery::DiscoverySpec;
...
@@ -10,9 +10,10 @@ use dynamo_runtime::discovery::DiscoverySpec;
use
dynamo_runtime
::
protocols
::
EndpointId
;
use
dynamo_runtime
::
protocols
::
EndpointId
;
use
dynamo_runtime
::
slug
::
Slug
;
use
dynamo_runtime
::
slug
::
Slug
;
use
dynamo_runtime
::
traits
::
DistributedRuntimeProvider
;
use
dynamo_runtime
::
traits
::
DistributedRuntimeProvider
;
use
dynamo_runtime
::
utils
::
get_http_rpc_host_from_env
;
use
crate
::
entrypoint
::
RouterConfig
;
use
crate
::
entrypoint
::
RouterConfig
;
use
crate
::
mocker
::
protocols
::
MockEngineArgs
;
use
crate
::
mocker
::
protocols
::
{
MockEngineArgs
,
WorkerType
}
;
use
crate
::
model_card
::
ModelDeploymentCard
;
use
crate
::
model_card
::
ModelDeploymentCard
;
use
crate
::
model_type
::{
ModelInput
,
ModelType
};
use
crate
::
model_type
::{
ModelInput
,
ModelType
};
use
crate
::
preprocessor
::
media
::{
ImageDecoder
,
MediaDecoder
,
MediaFetcher
};
use
crate
::
preprocessor
::
media
::{
ImageDecoder
,
MediaDecoder
,
MediaFetcher
};
...
@@ -249,6 +250,22 @@ impl LocalModelBuilder {
...
@@ -249,6 +250,22 @@ impl LocalModelBuilder {
video
:
None
,
video
:
None
,
});
});
self
.media_fetcher
=
Some
(
MediaFetcher
::
default
());
self
.media_fetcher
=
Some
(
MediaFetcher
::
default
());
// Set bootstrap endpoint for prefill workers with bootstrap_port configured
if
mocker_engine_args
.worker_type
==
WorkerType
::
Prefill
&&
let
Some
(
port
)
=
mocker_engine_args
.bootstrap_port
{
let
host
=
get_http_rpc_host_from_env
();
self
.runtime_config.disaggregated_endpoint
=
Some
(
runtime_config
::
DisaggregatedEndpoint
{
bootstrap_host
:
Some
(
host
),
bootstrap_port
:
Some
(
port
),
});
tracing
::
info!
(
bootstrap_port
=
port
,
"Mocker prefill worker: publishing bootstrap endpoint to discovery"
);
}
}
}
// frontend and echo engine don't need a path.
// frontend and echo engine don't need a path.
...
...
lib/llm/src/mocker.rs
View file @
0980b27f
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: Apache-2.0
pub
mod
bootstrap
;
pub
mod
engine
;
pub
mod
engine
;
pub
mod
evictor
;
pub
mod
evictor
;
pub
mod
kv_manager
;
pub
mod
kv_manager
;
...
...
lib/llm/src/mocker/bootstrap.rs
0 → 100644
View file @
0980b27f
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Bootstrap rendezvous for disaggregated mocker testing.
//!
//! Simulates the SGLang disaggregated serving handshake for KV transfer coordination.
//! Either prefill or decode can arrive first; the rendezvous completes when both are ready.
//!
//! - Prefill: calls `complete_room(room_id)` after first token (KV cache ready)
//! - Decode: connects to prefill's bootstrap server, blocks until prefill completes
//!
//! Wire protocol:
//! - Decode -> Prefill: room_id (8 bytes, little-endian u64)
//! - Prefill -> Decode: ACK (1 byte, 0x01) after prefill completes
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
anyhow
::{
Result
,
bail
};
use
dashmap
::
DashMap
;
use
dashmap
::
mapref
::
entry
::
Entry
;
use
tokio
::
io
::{
AsyncReadExt
,
AsyncWriteExt
};
use
tokio
::
net
::{
TcpListener
,
TcpStream
};
use
tokio
::
sync
::
oneshot
;
use
tokio_util
::
sync
::
CancellationToken
;
/// Upper bound for a single bootstrap rendezvous wait.
///
/// The bound is applied to each wait separately rather than to the whole
/// handshake: the server-side wait for prefill completion, the decode-side
/// TCP connect, and the decode-side wait for the ACK byte.
const
RENDEZVOUS_TIMEOUT
:
Duration
=
Duration
::
from_secs
(
30
);
/// Single byte the server writes to the decode connection once the room's
/// prefill has completed; the decode side rejects any other value.
const
ACK_BYTE
:
u8
=
0x01
;
/// Per-room rendezvous state, keyed by room id in the server's room map.
///
/// An entry is created by whichever side arrives first: the decode connection
/// (with `decode_waiting` set) or the prefill completion (with
/// `prefill_completed` set).
struct
RoomState
{
/// True once prefill has completed (KV cache ready) and no decode has
/// claimed the room yet; a later decode connection gets an immediate ACK.
prefill_completed
:
bool
,
/// Sender half used to wake a decode connection that is blocked waiting for
/// prefill to complete. `None` when no decode is currently waiting.
decode_waiting
:
Option
<
oneshot
::
Sender
<
()
>>
,
}
/// Bootstrap server for prefill mockers.
/// Handles rendezvous between prefill and decode for KV transfer coordination.
///
/// Created via `start`, which returns it behind an `Arc`; the room map is
/// shared with the spawned accept loop and its per-connection tasks.
pub
struct
BootstrapServer
{
/// Port the listener is actually bound to (resolved from the socket after
/// binding, so it is meaningful even when port 0 was requested).
port
:
u16
,
/// Rendezvous state per room id, shared between `complete_room` callers and
/// connection handler tasks.
rooms
:
Arc
<
DashMap
<
u64
,
RoomState
>>
,
}
impl
BootstrapServer
{
/// Start the bootstrap server on the specified port.
pub
async
fn
start
(
port
:
u16
,
cancel_token
:
CancellationToken
)
->
Result
<
Arc
<
Self
>>
{
let
listener
=
TcpListener
::
bind
(
format!
(
"0.0.0.0:{port}"
))
.await
?
;
let
actual_port
=
listener
.local_addr
()
?
.port
();
tracing
::
info!
(
"Bootstrap server started on port {actual_port}"
);
let
rooms
:
Arc
<
DashMap
<
u64
,
RoomState
>>
=
Arc
::
new
(
DashMap
::
new
());
let
server
=
Arc
::
new
(
Self
{
port
:
actual_port
,
rooms
:
rooms
.clone
(),
});
// Spawn accept loop
tokio
::
spawn
(
async
move
{
loop
{
tokio
::
select!
{
result
=
listener
.accept
()
=>
{
match
result
{
Ok
((
stream
,
addr
))
=>
{
tracing
::
debug!
(
"Bootstrap: accepted connection from {addr}"
);
let
rooms_clone
=
rooms
.clone
();
tokio
::
spawn
(
async
move
{
if
let
Err
(
e
)
=
Self
::
handle_connection
(
stream
,
rooms_clone
)
.await
{
tracing
::
warn!
(
"Bootstrap: connection error: {e}"
);
}
});
}
Err
(
e
)
=>
{
tracing
::
warn!
(
"Bootstrap: accept failed: {e}"
);
}
}
}
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
debug!
(
"Bootstrap server shutting down"
);
break
;
}
}
}
});
Ok
(
server
)
}
/// Handle a connection from decode. Blocks until prefill completes for this room.
async
fn
handle_connection
(
mut
stream
:
TcpStream
,
rooms
:
Arc
<
DashMap
<
u64
,
RoomState
>>
,
)
->
Result
<
()
>
{
// Read room_id (8 bytes, little-endian)
let
mut
buf
=
[
0u8
;
8
];
stream
.read_exact
(
&
mut
buf
)
.await
?
;
let
room_id
=
u64
::
from_le_bytes
(
buf
);
tracing
::
debug!
(
"Bootstrap: decode connected for room {room_id}"
);
// Check room state and wait if needed
let
rx
=
match
rooms
.entry
(
room_id
)
{
Entry
::
Occupied
(
mut
entry
)
=>
{
if
entry
.get
()
.prefill_completed
{
// Prefill already done, immediate ACK
entry
.remove
();
tracing
::
debug!
(
"Bootstrap: room {room_id} already completed, immediate ACK"
);
None
}
else
{
// Prefill registered but not completed, wait
let
(
tx
,
rx
)
=
oneshot
::
channel
();
entry
.get_mut
()
.decode_waiting
=
Some
(
tx
);
tracing
::
debug!
(
"Bootstrap: room {room_id} waiting for prefill to complete"
);
Some
(
rx
)
}
}
Entry
::
Vacant
(
entry
)
=>
{
// Decode arrived first, create entry and wait
let
(
tx
,
rx
)
=
oneshot
::
channel
();
entry
.insert
(
RoomState
{
prefill_completed
:
false
,
decode_waiting
:
Some
(
tx
),
});
tracing
::
debug!
(
"Bootstrap: room {room_id} decode arrived first, waiting"
);
Some
(
rx
)
}
};
// Wait for prefill to complete if needed
if
let
Some
(
rx
)
=
rx
{
match
tokio
::
time
::
timeout
(
RENDEZVOUS_TIMEOUT
,
rx
)
.await
{
Ok
(
Ok
(()))
=>
{
tracing
::
debug!
(
"Bootstrap: room {room_id} prefill completed, sending ACK"
);
}
Ok
(
Err
(
_
))
=>
{
bail!
(
"Bootstrap: room {room_id} sender dropped"
);
}
Err
(
_
)
=>
{
rooms
.remove
(
&
room_id
);
bail!
(
"Bootstrap: room {room_id} timeout waiting for prefill"
);
}
}
}
// Send ACK
stream
.write_all
(
&
[
ACK_BYTE
])
.await
?
;
Ok
(())
}
/// Record that prefill finished for `room_id` (KV cache ready).
///
/// - A decode connection is already waiting: it is signalled and the room
///   entry is removed.
/// - The room exists but nobody is waiting: it is flagged as completed.
/// - The room does not exist yet: a completed entry is inserted so a later
///   decode connection gets an immediate ACK.
///
/// Note: a completed entry stays in the map until a decode connection for
/// that room arrives and claims it; this method never expires entries.
pub fn complete_room(&self, room_id: u64) {
    match self.rooms.entry(room_id) {
        Entry::Vacant(slot) => {
            // Decode hasn't connected yet
            slot.insert(RoomState {
                prefill_completed: true,
                decode_waiting: None,
            });
            tracing::debug!("Bootstrap: room {room_id} completed (no decode yet)");
        }
        Entry::Occupied(mut slot) => match slot.get_mut().decode_waiting.take() {
            Some(waiter) => {
                // Decode is waiting, unblock it. A send error only means the
                // waiter already went away, which is deliberately ignored.
                let _ = waiter.send(());
                slot.remove();
                tracing::debug!("Bootstrap: room {room_id} completed, decode unblocked");
            }
            None => {
                // Decode not connected yet, mark completed
                slot.get_mut().prefill_completed = true;
                tracing::debug!("Bootstrap: room {room_id} completed, awaiting decode");
            }
        },
    }
}
/// Get the port the server is listening on.
///
/// This is the port read back from the bound socket in `start`, so it is the
/// OS-assigned port when the server was started with port 0.
pub
fn
port
(
&
self
)
->
u16
{
self
.port
}
}
/// Decode-side half of the rendezvous: dial a prefill worker's bootstrap
/// server and block until it acknowledges that the room's KV is ready.
///
/// Any `[` / `]` characters at either end of `host` are stripped before the
/// address is built as `{host}:{port}`. After connecting, the room id is sent
/// as 8 little-endian bytes and a single ACK byte is awaited.
///
/// # Errors
/// Fails if the connect or the ACK wait exceeds `RENDEZVOUS_TIMEOUT`, on
/// socket I/O errors, or if the byte received is not `ACK_BYTE`.
pub async fn connect_to_prefill(host: &str, port: u16, room_id: u64) -> Result<()> {
    let bare_host = host.trim_matches(['[', ']']);
    let addr = format!("{bare_host}:{port}");

    tracing::debug!("Bootstrap: decode connecting to {addr} for room {room_id}");

    // Connect with timeout
    let dial = tokio::time::timeout(RENDEZVOUS_TIMEOUT, TcpStream::connect(&addr)).await;
    let mut conn = match dial {
        Err(_) => bail!("Bootstrap: connect timeout to {addr}"),
        Ok(Err(e)) => bail!("Bootstrap: connect failed to {addr}: {e}"),
        Ok(Ok(conn)) => conn,
    };

    // Send room_id
    let room_bytes = room_id.to_le_bytes();
    conn.write_all(&room_bytes).await?;

    // Wait for ACK (blocks until prefill completes)
    let mut reply = [0u8; 1];
    match tokio::time::timeout(RENDEZVOUS_TIMEOUT, conn.read_exact(&mut reply)).await {
        Err(_) => bail!("Bootstrap: ACK timeout for room {room_id}"),
        Ok(Err(e)) => bail!("Bootstrap: read ACK failed: {e}"),
        Ok(Ok(_)) => {}
    }

    let received = reply[0];
    if received != ACK_BYTE {
        bail!(
            "Bootstrap: invalid ACK byte {:02x} for room {room_id}",
            received
        );
    }

    tracing::debug!("Bootstrap: decode received ACK for room {room_id}");
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address every test dials.
    const LOCALHOST: &str = "127.0.0.1";

    /// Start a server on an OS-assigned port and hand back the server, the
    /// port it bound to, and the token that shuts its accept loop down.
    async fn spawn_server() -> (Arc<BootstrapServer>, u16, CancellationToken) {
        let cancel_token = CancellationToken::new();
        let server = BootstrapServer::start(0, cancel_token.clone())
            .await
            .unwrap();
        let port = server.port();
        (server, port, cancel_token)
    }

    /// Prefill completes before decode connects: decode gets an immediate ACK.
    #[tokio::test]
    async fn test_prefill_completes_first() {
        let (server, port, cancel_token) = spawn_server().await;
        let room_id = 1001u64;

        server.complete_room(room_id);

        let result = connect_to_prefill(LOCALHOST, port, room_id).await;
        assert!(result.is_ok(), "Decode should succeed: {result:?}");

        cancel_token.cancel();
    }

    /// Decode connects and blocks; a later prefill completion releases it.
    #[tokio::test]
    async fn test_decode_connects_first() {
        let (server, port, cancel_token) = spawn_server().await;
        let room_id = 1002u64;

        let decode_handle =
            tokio::spawn(async move { connect_to_prefill(LOCALHOST, port, room_id).await });

        // Give decode time to connect and register
        tokio::time::sleep(Duration::from_millis(50)).await;

        server.complete_room(room_id);

        let result = decode_handle.await.unwrap();
        assert!(result.is_ok(), "Decode should succeed: {result:?}");

        cancel_token.cancel();
    }

    /// Decode starts connecting after a short delay; prefill completes later.
    #[tokio::test]
    async fn test_interleaved_ordering() {
        let (server, port, cancel_token) = spawn_server().await;
        let room_id = 1003u64;

        let decode_handle = tokio::spawn(async move {
            // Small delay before decode dials in
            tokio::time::sleep(Duration::from_millis(10)).await;
            connect_to_prefill(LOCALHOST, port, room_id).await
        });

        // Prefill completes after decode starts connecting
        tokio::time::sleep(Duration::from_millis(50)).await;
        server.complete_room(room_id);

        let result = decode_handle.await.unwrap();
        assert!(result.is_ok(), "Decode should succeed: {result:?}");

        cancel_token.cancel();
    }

    /// Three rooms with different arrival orders run concurrently on one server.
    #[tokio::test]
    async fn test_multiple_rooms_concurrent() {
        let (server, port, cancel_token) = spawn_server().await;

        // Room 2001: prefill first
        let prefill_first = {
            let server = server.clone();
            tokio::spawn(async move {
                server.complete_room(2001);
                tokio::time::sleep(Duration::from_millis(10)).await;
                connect_to_prefill(LOCALHOST, port, 2001).await
            })
        };

        // Room 2002: decode first
        let decode_first = {
            let server = server.clone();
            tokio::spawn(async move {
                let decode = tokio::spawn(connect_to_prefill(LOCALHOST, port, 2002));
                tokio::time::sleep(Duration::from_millis(50)).await;
                server.complete_room(2002);
                decode.await.unwrap()
            })
        };

        // Room 2003: simultaneous
        let simultaneous = {
            let server = server.clone();
            tokio::spawn(async move {
                let decode = tokio::spawn(connect_to_prefill(LOCALHOST, port, 2003));
                server.complete_room(2003);
                decode.await.unwrap()
            })
        };

        let handles = vec![prefill_first, decode_first, simultaneous];
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap();
            assert!(
                result.is_ok(),
                "Room {} should succeed: {result:?}",
                2001 + i
            );
        }

        cancel_token.cancel();
    }

    /// Decode connects but prefill never completes: the caller's own short
    /// timeout fires long before the internal RENDEZVOUS_TIMEOUT.
    #[tokio::test]
    async fn test_decode_timeout_no_prefill() {
        let (_server, port, cancel_token) = spawn_server().await;
        let room_id = 9999u64;

        let result = tokio::time::timeout(
            Duration::from_millis(100),
            connect_to_prefill(LOCALHOST, port, room_id),
        )
        .await;

        // Should timeout (outer timeout, not inner RENDEZVOUS_TIMEOUT)
        assert!(result.is_err(), "Should timeout waiting for prefill");

        cancel_token.cancel();
    }
}
lib/llm/src/mocker/engine.rs
View file @
0980b27f
...
@@ -28,6 +28,7 @@ use dynamo_runtime::{
...
@@ -28,6 +28,7 @@ use dynamo_runtime::{
};
};
use
crate
::
kv_router
::
publisher
::
WorkerMetricsPublisher
;
use
crate
::
kv_router
::
publisher
::
WorkerMetricsPublisher
;
use
crate
::
mocker
::
bootstrap
::{
BootstrapServer
,
connect_to_prefill
};
use
crate
::
mocker
::
protocols
::
DirectRequest
;
use
crate
::
mocker
::
protocols
::
DirectRequest
;
use
crate
::
mocker
::
protocols
::{
MockEngineArgs
,
OutputSignal
,
WorkerType
};
use
crate
::
mocker
::
protocols
::{
MockEngineArgs
,
OutputSignal
,
WorkerType
};
use
crate
::
mocker
::
scheduler
::
Scheduler
;
use
crate
::
mocker
::
scheduler
::
Scheduler
;
...
@@ -47,6 +48,8 @@ pub struct MockVllmEngine {
...
@@ -47,6 +48,8 @@ pub struct MockVllmEngine {
active_requests
:
Arc
<
Mutex
<
HashMap
<
Uuid
,
mpsc
::
UnboundedSender
<
OutputSignal
>>>>
,
active_requests
:
Arc
<
Mutex
<
HashMap
<
Uuid
,
mpsc
::
UnboundedSender
<
OutputSignal
>>>>
,
request_senders
:
Arc
<
OnceCell
<
Vec
<
mpsc
::
UnboundedSender
<
DirectRequest
>>>>
,
request_senders
:
Arc
<
OnceCell
<
Vec
<
mpsc
::
UnboundedSender
<
DirectRequest
>>>>
,
engine_args
:
MockEngineArgs
,
engine_args
:
MockEngineArgs
,
/// Bootstrap server for prefill workers in disaggregated mode
bootstrap_server
:
Arc
<
OnceCell
<
Arc
<
BootstrapServer
>>>
,
}
}
impl
MockVllmEngine
{
impl
MockVllmEngine
{
...
@@ -56,6 +59,7 @@ impl MockVllmEngine {
...
@@ -56,6 +59,7 @@ impl MockVllmEngine {
active_requests
:
Arc
::
new
(
Mutex
::
new
(
HashMap
::
new
())),
active_requests
:
Arc
::
new
(
Mutex
::
new
(
HashMap
::
new
())),
request_senders
:
Arc
::
new
(
OnceCell
::
new
()),
request_senders
:
Arc
::
new
(
OnceCell
::
new
()),
engine_args
:
args
,
engine_args
:
args
,
bootstrap_server
:
Arc
::
new
(
OnceCell
::
new
()),
}
}
}
}
...
@@ -73,6 +77,15 @@ impl MockVllmEngine {
...
@@ -73,6 +77,15 @@ impl MockVllmEngine {
tracing
::
info!
(
"Engine startup simulation completed"
);
tracing
::
info!
(
"Engine startup simulation completed"
);
}
}
// Start bootstrap server for prefill workers in disaggregated mode
if
self
.engine_args.worker_type
==
WorkerType
::
Prefill
&&
let
Some
(
port
)
=
self
.engine_args.bootstrap_port
{
let
server
=
BootstrapServer
::
start
(
port
,
cancel_token
.clone
())
.await
?
;
let
_
=
self
.bootstrap_server
.set
(
server
);
tracing
::
info!
(
port
=
port
,
"Bootstrap server started for prefill worker"
);
}
// Pass component to schedulers only if prefix caching is enabled and not a decode worker
// Pass component to schedulers only if prefix caching is enabled and not a decode worker
let
scheduler_component
=
if
self
.engine_args.enable_prefix_caching
let
scheduler_component
=
if
self
.engine_args.enable_prefix_caching
&&
self
.engine_args.worker_type
!=
WorkerType
::
Decode
&&
self
.engine_args.worker_type
!=
WorkerType
::
Decode
...
@@ -253,6 +266,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
...
@@ -253,6 +266,22 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
)));
)));
}
}
// Bootstrap rendezvous for disaggregated serving
// - Decode: connect to prefill's server, block until prefill completes
// - Prefill: complete_room() is called after first token (see below)
let
bootstrap_room
=
request
.bootstrap_info
.as_ref
()
.map
(|
b
|
b
.bootstrap_room
);
if
let
Some
(
bootstrap_info
)
=
&
request
.bootstrap_info
&&
self
.engine_args.worker_type
==
WorkerType
::
Decode
{
connect_to_prefill
(
&
bootstrap_info
.bootstrap_host
,
bootstrap_info
.bootstrap_port
,
bootstrap_info
.bootstrap_room
,
)
.await
.map_err
(|
e
|
Error
::
msg
(
format!
(
"Bootstrap connection failed: {e}"
)))
?
;
}
let
request_uuid
=
ctx
.id
()
.parse
()
.unwrap_or
(
Uuid
::
new_v4
());
let
request_uuid
=
ctx
.id
()
.parse
()
.unwrap_or
(
Uuid
::
new_v4
());
// For prefill workers, override max_tokens to 1
// For prefill workers, override max_tokens to 1
...
@@ -288,6 +317,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
...
@@ -288,6 +317,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let
active_requests
=
self
.active_requests
.clone
();
let
active_requests
=
self
.active_requests
.clone
();
let
async_context
=
ctx
.context
();
let
async_context
=
ctx
.context
();
let
bootstrap_server
=
self
.bootstrap_server
.clone
();
// Spawn a task to handle the complex async logic
// Spawn a task to handle the complex async logic
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
...
@@ -325,6 +355,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
...
@@ -325,6 +355,14 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
completion_usage
:
None
,
completion_usage
:
None
,
};
};
// Prefill: after first token, mark room complete (unblocks decode)
if
is_prefill
&&
token_count
==
1
&&
let
(
Some
(
server
),
Some
(
room_id
))
=
(
bootstrap_server
.get
(),
bootstrap_room
)
{
server
.complete_room
(
room_id
);
}
if
signal
.completed
&&
token_count
<
max_output_tokens
{
if
signal
.completed
&&
token_count
<
max_output_tokens
{
let
_
=
stream_tx
.send
(
LLMEngineOutput
::
error
(
"Completion signal received before max tokens reached"
.to_string
()));
let
_
=
stream_tx
.send
(
LLMEngineOutput
::
error
(
"Completion signal received before max tokens reached"
.to_string
()));
break
;
break
;
...
...
lib/llm/src/mocker/protocols.rs
View file @
0980b27f
...
@@ -124,6 +124,12 @@ pub struct MockEngineArgs {
...
@@ -124,6 +124,12 @@ pub struct MockEngineArgs {
/// Enable worker-local KV indexer for tracking this worker's own KV cache state
/// Enable worker-local KV indexer for tracking this worker's own KV cache state
#[builder(default
=
"false"
)]
#[builder(default
=
"false"
)]
pub
enable_local_indexer
:
bool
,
pub
enable_local_indexer
:
bool
,
/// Bootstrap port for disaggregated serving rendezvous.
/// Prefill workers listen on this port; decode workers connect to it.
/// If None, bootstrap rendezvous is disabled.
#[builder(default
=
"None"
)]
pub
bootstrap_port
:
Option
<
u16
>
,
}
}
impl
Default
for
MockEngineArgs
{
impl
Default
for
MockEngineArgs
{
...
@@ -163,6 +169,7 @@ impl MockEngineArgs {
...
@@ -163,6 +169,7 @@ impl MockEngineArgs {
"is_decode"
,
"is_decode"
,
"planner_profile_data"
,
"planner_profile_data"
,
"enable_local_indexer"
,
"enable_local_indexer"
,
"bootstrap_port"
,
]
]
.iter
()
.iter
()
.cloned
()
.cloned
()
...
@@ -250,6 +257,12 @@ impl MockEngineArgs {
...
@@ -250,6 +257,12 @@ impl MockEngineArgs {
builder
=
builder
.enable_local_indexer
(
enabled
);
builder
=
builder
.enable_local_indexer
(
enabled
);
}
}
if
let
Some
(
value
)
=
extra_args
.get
(
"bootstrap_port"
)
&&
let
Some
(
port
)
=
value
.as_u64
()
{
builder
=
builder
.bootstrap_port
(
Some
(
port
as
u16
));
}
// Parse worker type from is_prefill and is_decode flags
// Parse worker type from is_prefill and is_decode flags
let
is_prefill
=
extra_args
let
is_prefill
=
extra_args
.get
(
"is_prefill"
)
.get
(
"is_prefill"
)
...
...
lib/llm/src/preprocessor.rs
View file @
0980b27f
...
@@ -245,6 +245,7 @@ impl OpenAIPreprocessor {
...
@@ -245,6 +245,7 @@ impl OpenAIPreprocessor {
prefill_worker_id
:
nvext
.prefill_worker_id
,
prefill_worker_id
:
nvext
.prefill_worker_id
,
decode_worker_id
:
nvext
.decode_worker_id
,
decode_worker_id
:
nvext
.decode_worker_id
,
dp_rank
:
None
,
// dp_rank is set later in the pipeline
dp_rank
:
None
,
// dp_rank is set later in the pipeline
enable_local_updates
:
nvext
.enable_local_updates
,
};
};
builder
.routing
(
Some
(
routing
));
builder
.routing
(
Some
(
routing
));
}
}
...
...
lib/llm/src/protocols/common/preprocessor.rs
View file @
0980b27f
...
@@ -34,6 +34,14 @@ pub struct RoutingHints {
...
@@ -34,6 +34,14 @@ pub struct RoutingHints {
/// Data parallel rank for the request
/// Data parallel rank for the request
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
dp_rank
:
Option
<
u32
>
,
pub
dp_rank
:
Option
<
u32
>
,
/// Controls whether the router should manage local bookkeeping (add_request,
/// mark_prefill_completed, free) for this request.
///
/// - `None` or `Some(true)`: Router handles bookkeeping locally (default behavior)
/// - `Some(false)`: External caller (e.g., GAIE sidecar) handles bookkeeping via C FFI
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
enable_local_updates
:
Option
<
bool
>
,
}
}
#[derive(Serialize,
Deserialize,
Debug,
Clone,
Default)]
#[derive(Serialize,
Deserialize,
Debug,
Clone,
Default)]
...
...
lib/llm/src/protocols/common/timing.rs
View file @
0980b27f
...
@@ -6,9 +6,12 @@
...
@@ -6,9 +6,12 @@
//! This module provides [`RequestTracker`] for tracking timing and routing information
//! This module provides [`RequestTracker`] for tracking timing and routing information
//! that can be returned to clients via the `nvext` response field.
//! that can be returned to clients via the `nvext` response field.
use
serde
::{
Deserialize
,
Serialize
};
use
std
::
sync
::{
Arc
,
OnceLock
};
use
std
::
sync
::{
Mutex
,
OnceLock
};
use
std
::
time
::{
Instant
,
SystemTime
,
UNIX_EPOCH
};
use
std
::
time
::{
Instant
,
SystemTime
,
UNIX_EPOCH
};
use
parking_lot
::
Mutex
;
use
serde
::{
Deserialize
,
Serialize
};
use
tokio
::
sync
::{
OwnedSemaphorePermit
,
Semaphore
};
use
utoipa
::
ToSchema
;
use
utoipa
::
ToSchema
;
use
crate
::
protocols
::
openai
::
nvext
::
WorkerIdInfo
;
use
crate
::
protocols
::
openai
::
nvext
::
WorkerIdInfo
;
...
@@ -80,6 +83,12 @@ pub struct RequestTracker {
...
@@ -80,6 +83,12 @@ pub struct RequestTracker {
/// Request phase (Prefill/Decode/Aggregated)
/// Request phase (Prefill/Decode/Aggregated)
phase
:
Mutex
<
RequestPhase
>
,
phase
:
Mutex
<
RequestPhase
>
,
/// Semaphore for coordinating phase transitions.
/// Acquiring a permit blocks subsequent set_phase calls until the permit is dropped.
/// This prevents race conditions in the bootstrap optimization path where prefill
/// runs in background and needs to complete record_worker before phase changes.
phase_semaphore
:
Arc
<
Semaphore
>
,
}
}
impl
RequestTracker
{
impl
RequestTracker
{
...
@@ -102,6 +111,7 @@ impl RequestTracker {
...
@@ -102,6 +111,7 @@ impl RequestTracker {
prefill_worker_id
:
OnceLock
::
new
(),
prefill_worker_id
:
OnceLock
::
new
(),
decode_worker_id
:
OnceLock
::
new
(),
decode_worker_id
:
OnceLock
::
new
(),
phase
:
Mutex
::
new
(
RequestPhase
::
Aggregated
),
phase
:
Mutex
::
new
(
RequestPhase
::
Aggregated
),
phase_semaphore
:
Arc
::
new
(
Semaphore
::
new
(
1
)),
}
}
}
}
...
@@ -175,14 +185,29 @@ impl RequestTracker {
...
@@ -175,14 +185,29 @@ impl RequestTracker {
self
.decode_worker_id
.set
(
id
)
.is_ok
()
self
.decode_worker_id
.set
(
id
)
.is_ok
()
}
}
/// Set the request phase. Can be called multiple times to update the phase.
/// Set the request phase and return a permit that blocks subsequent phase changes.
pub
fn
set_phase
(
&
self
,
phase
:
RequestPhase
)
{
///
*
self
.phase
.lock
()
.unwrap
()
=
phase
;
/// The returned permit must be dropped to allow the next `set_phase` call to proceed.
/// Under normal operation, callers can simply ignore the returned permit (letting it
/// drop immediately). In the bootstrap optimization path, the permit is held and
/// passed to the spawned prefill task, which drops it after `record_worker` completes.
///
/// This prevents the race condition where the phase changes to Decode before the
/// background prefill task has recorded its worker ID.
pub
async
fn
set_phase
(
&
self
,
phase
:
RequestPhase
)
->
OwnedSemaphorePermit
{
let
permit
=
self
.phase_semaphore
.clone
()
.acquire_owned
()
.await
.expect
(
"phase semaphore should never be closed"
);
*
self
.phase
.lock
()
=
phase
;
permit
}
}
/// Get the current request phase.
/// Get the current request phase.
pub
fn
phase
(
&
self
)
->
RequestPhase
{
pub
fn
phase
(
&
self
)
->
RequestPhase
{
*
self
.phase
.lock
()
.unwrap
()
*
self
.phase
.lock
()
}
}
/// Record worker ID based on the current phase.
/// Record worker ID based on the current phase.
...
...
lib/llm/src/protocols/openai/nvext.rs
View file @
0980b27f
...
@@ -105,6 +105,17 @@ pub struct NvExt {
...
@@ -105,6 +105,17 @@ pub struct NvExt {
#[builder(default,
setter(strip_option))]
#[builder(default,
setter(strip_option))]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
decode_worker_id
:
Option
<
u64
>
,
pub
decode_worker_id
:
Option
<
u64
>
,
/// Controls whether the router should manage local bookkeeping (add_request,
/// mark_prefill_completed, free) for this request.
///
/// - `None` or `true`: Router handles bookkeeping locally (default behavior)
/// - `false`: External caller (e.g., GAIE sidecar) handles bookkeeping via C FFI
///
/// Set to `false` for GAIE Stage 2 when the EPP/sidecar manages request lifecycle.
#[builder(default,
setter(strip_option))]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
enable_local_updates
:
Option
<
bool
>
,
}
}
impl
Default
for
NvExt
{
impl
Default
for
NvExt
{
...
@@ -153,6 +164,7 @@ mod tests {
...
@@ -153,6 +164,7 @@ mod tests {
assert_eq!
(
nv_ext
.extra_fields
,
None
);
assert_eq!
(
nv_ext
.extra_fields
,
None
);
assert_eq!
(
nv_ext
.prefill_worker_id
,
None
);
assert_eq!
(
nv_ext
.prefill_worker_id
,
None
);
assert_eq!
(
nv_ext
.decode_worker_id
,
None
);
assert_eq!
(
nv_ext
.decode_worker_id
,
None
);
assert_eq!
(
nv_ext
.enable_local_updates
,
None
);
}
}
// Test valid builder configurations
// Test valid builder configurations
...
...
lib/runtime/src/component/endpoint.rs
View file @
0980b27f
...
@@ -230,7 +230,7 @@ impl EndpointConfigBuilder {
...
@@ -230,7 +230,7 @@ impl EndpointConfigBuilder {
/// This function handles both health check and discovery transport building.
/// This function handles both health check and discovery transport building.
/// All transport modes use consistent addressing:
/// All transport modes use consistent addressing:
/// - HTTP: Uses full URL path including endpoint name (e.g., http://host:port/v1/rpc/endpoint_name)
/// - HTTP: Uses full URL path including endpoint name (e.g., http://host:port/v1/rpc/endpoint_name)
/// - TCP: Includes endpoint name for routing (e.g., host:port/endpoint_name)
/// - TCP: Includes
instance_id and
endpoint name for routing (e.g., host:port/
instance_id_hex/
endpoint_name)
/// - NATS: Uses subject-based addressing (unique per endpoint)
/// - NATS: Uses subject-based addressing (unique per endpoint)
///
///
/// # Errors
/// # Errors
...
@@ -266,9 +266,14 @@ fn build_transport_type_inner(
...
@@ -266,9 +266,14 @@ fn build_transport_type_inner(
.and_then
(|
p
|
p
.parse
::
<
u16
>
()
.ok
())
.and_then
(|
p
|
p
.parse
::
<
u16
>
()
.ok
())
.unwrap_or
(
crate
::
pipeline
::
network
::
manager
::
get_actual_tcp_rpc_port
()
?
);
.unwrap_or
(
crate
::
pipeline
::
network
::
manager
::
get_actual_tcp_rpc_port
()
?
);
// Include endpoint name for proper TCP routing
// Include instance_id and endpoint name for proper TCP routing.
// TCP client parses this format and adds x-endpoint-path header for server-side routing
// Format: host:port/instance_id_hex/endpoint_name
let
tcp_endpoint
=
format!
(
"{}:{}/{}"
,
tcp_host
,
tcp_port
,
endpoint_id
.name
);
// This ensures each worker has a unique routing key when multiple workers
// share the same TCP server (e.g., --num-workers > 1).
let
tcp_endpoint
=
format!
(
"{}:{}/{:x}/{}"
,
tcp_host
,
tcp_port
,
connection_id
,
endpoint_id
.name
);
Ok
(
TransportType
::
Tcp
(
tcp_endpoint
))
Ok
(
TransportType
::
Tcp
(
tcp_endpoint
))
}
}
...
...
lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs
View file @
0980b27f
...
@@ -413,9 +413,11 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer {
...
@@ -413,9 +413,11 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer {
component_name
:
String
,
component_name
:
String
,
system_health
:
Arc
<
Mutex
<
SystemHealth
>>
,
system_health
:
Arc
<
Mutex
<
SystemHealth
>>
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
// For TCP, we use endpoint_name as both the endpoint_path (routing key) and endpoint_name
// Include instance_id in the routing key to avoid collisions when multiple workers
// share the same TCP server (e.g., --num-workers > 1 in tests)
let
endpoint_path
=
format!
(
"{instance_id:x}/{endpoint_name}"
);
self
.register_endpoint
(
self
.register_endpoint
(
endpoint_
name
.clone
()
,
endpoint_
path
,
service_handler
,
service_handler
,
instance_id
,
instance_id
,
namespace
,
namespace
,
...
@@ -427,7 +429,19 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer {
...
@@ -427,7 +429,19 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer {
}
}
async
fn
unregister_endpoint
(
&
self
,
endpoint_name
:
&
str
)
->
Result
<
()
>
{
async
fn
unregister_endpoint
(
&
self
,
endpoint_name
:
&
str
)
->
Result
<
()
>
{
self
.unregister_endpoint
(
endpoint_name
,
endpoint_name
)
.await
;
// With multiple workers per process, each registers with a unique key
// "{instance_id}/{endpoint_name}". Find and remove all matching entries.
let
suffix
=
format!
(
"/{endpoint_name}"
);
let
keys_to_remove
:
Vec
<
String
>
=
self
.handlers
.iter
()
.filter
(|
entry
|
entry
.key
()
.ends_with
(
&
suffix
))
.map
(|
entry
|
entry
.key
()
.clone
())
.collect
();
for
key
in
keys_to_remove
{
self
.unregister_endpoint
(
&
key
,
endpoint_name
)
.await
;
}
Ok
(())
Ok
(())
}
}
...
...
lib/runtime/src/pipeline/network/manager.rs
View file @
0980b27f
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
//! directly accesses transport implementations or configuration.
//! directly accesses transport implementations or configuration.
use
super
::
egress
::
unified_client
::
RequestPlaneClient
;
use
super
::
egress
::
unified_client
::
RequestPlaneClient
;
use
super
::
ingress
::
shared_tcp_endpoint
::
SharedTcpServer
;
use
super
::
ingress
::
unified_server
::
RequestPlaneServer
;
use
super
::
ingress
::
unified_server
::
RequestPlaneServer
;
use
crate
::
distributed
::
RequestPlaneMode
;
use
crate
::
distributed
::
RequestPlaneMode
;
use
anyhow
::
Result
;
use
anyhow
::
Result
;
...
@@ -26,6 +27,17 @@ use tokio_util::sync::CancellationToken;
...
@@ -26,6 +27,17 @@ use tokio_util::sync::CancellationToken;
/// Uses OnceLock since the port is set once when the server binds and never changes.
/// Uses OnceLock since the port is set once when the server binds and never changes.
static
ACTUAL_TCP_RPC_PORT
:
OnceLock
<
u16
>
=
OnceLock
::
new
();
static
ACTUAL_TCP_RPC_PORT
:
OnceLock
<
u16
>
=
OnceLock
::
new
();
/// Global storage for the shared TCP server instance.
///
/// When multiple workers run in the same process, they must share a single TCP server
/// to ensure all endpoints are registered on the same server. Without this, each worker
/// would create its own server on a different port, but all would publish the same port
/// (from ACTUAL_TCP_RPC_PORT) to discovery, causing "No handler found" errors.
///
/// Uses `tokio::sync::OnceCell` to support async initialization (binding the TCP socket).
static
GLOBAL_TCP_SERVER
:
tokio
::
sync
::
OnceCell
<
Arc
<
SharedTcpServer
>>
=
tokio
::
sync
::
OnceCell
::
const_new
();
/// Get the actual TCP RPC port that the server is listening on.
/// Get the actual TCP RPC port that the server is listening on.
pub
fn
get_actual_tcp_rpc_port
()
->
anyhow
::
Result
<
u16
>
{
pub
fn
get_actual_tcp_rpc_port
()
->
anyhow
::
Result
<
u16
>
{
ACTUAL_TCP_RPC_PORT
.get
()
.copied
()
.ok_or_else
(||
{
ACTUAL_TCP_RPC_PORT
.get
()
.copied
()
.ok_or_else
(||
{
...
@@ -300,35 +312,41 @@ impl NetworkManager {
...
@@ -300,35 +312,41 @@ impl NetworkManager {
}
}
async
fn
create_tcp_server
(
&
self
)
->
Result
<
Arc
<
dyn
RequestPlaneServer
>>
{
async
fn
create_tcp_server
(
&
self
)
->
Result
<
Arc
<
dyn
RequestPlaneServer
>>
{
use
super
::
ingress
::
shared_tcp_endpoint
::
SharedTcpServer
;
// Use the global TCP server to ensure all workers in the same process share
// a single server. This is critical for correct endpoint routing.
let
server
=
GLOBAL_TCP_SERVER
.get_or_try_init
(||
async
{
// Use configured port if specified, otherwise use port 0 (OS assigns free port)
let
port
=
self
.config.tcp_port
.unwrap_or
(
0
);
let
bind_addr
=
format!
(
"{}:{}"
,
self
.config.tcp_host
,
port
)
.parse
()
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Invalid TCP bind address: {}"
,
e
))
?
;
// Use configured port if specified, otherwise use port 0 (OS assigns free port)
tracing
::
info!
(
let
port
=
self
.config.tcp_port
.unwrap_or
(
0
);
bind_addr
=
%
bind_addr
,
let
bind_addr
=
format!
(
"{}:{}"
,
self
.config.tcp_host
,
port
)
port_source
=
if
self
.config.tcp_port
.is_some
()
{
"DYN_TCP_RPC_PORT"
}
else
{
"OS-assigned"
},
.parse
()
"Creating TCP request plane server"
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Invalid TCP bind address: {}"
,
e
))
?
;
)
;
tracing
::
info!
(
let
server
=
SharedTcpServer
::
new
(
bind_addr
,
self
.cancellation_token
.clone
());
bind_addr
=
%
bind_addr
,
port_source
=
if
self
.config.tcp_port
.is_some
()
{
"DYN_TCP_RPC_PORT"
}
else
{
"OS-assigned"
},
"Creating TCP request plane server"
);
let
server
=
SharedTcpServer
::
new
(
bind_addr
,
self
.cancellation_token
.clone
());
// Bind and start server, getting the actual bound address
let
actual_addr
=
server
.clone
()
.bind_and_start
()
.await
?
;
// Bind and start server, getting the actual bound address
// Store the actual bound port globally so build_transport_type() can access it
l
et
actual_
addr
=
server
.clone
()
.bind_and_start
()
.await
?
;
s
et
_
actual_
tcp_rpc_port
(
actual_addr
.port
())
;
// Store the actual bound port globally so build_transport_type() can access it
tracing
::
info!
(
set_actual_tcp_rpc_port
(
actual_addr
.port
());
actual_addr
=
%
actual_addr
,
actual_port
=
actual_addr
.port
(),
"TCP request plane server started"
);
tracing
::
info!
(
Ok
::
<
_
,
anyhow
::
Error
>
(
server
)
actual_addr
=
%
actual_addr
,
})
actual_port
=
actual_addr
.port
(),
.await
?
;
"TCP request plane server started"
);
Ok
(
server
as
Arc
<
dyn
RequestPlaneServer
>
)
Ok
(
server
.clone
()
as
Arc
<
dyn
RequestPlaneServer
>
)
}
}
async
fn
create_nats_server
(
&
self
)
->
Result
<
Arc
<
dyn
RequestPlaneServer
>>
{
async
fn
create_nats_server
(
&
self
)
->
Result
<
Arc
<
dyn
RequestPlaneServer
>>
{
...
...
tests/router/test_router_e2e_with_mockers.py
View file @
0980b27f
...
@@ -6,6 +6,12 @@
...
@@ -6,6 +6,12 @@
# Combined pre_merge wall time (this file):
# Combined pre_merge wall time (this file):
# - Serialized: 304.01s.
# - Serialized: 304.01s.
# - Parallel (-n auto): 34.55s (269.46s saved, 8.80x).
# - Parallel (-n auto): 34.55s (269.46s saved, 8.80x).
#
# NOTE: TCP request plane is NOT tested here. These tests use --num-workers > 1 which spawns
# multiple workers in a single process sharing one TCP server. The shared TCP server uses
# endpoint_path (e.g., "generate") as the routing key, causing handler collisions when multiple
# workers register the same endpoint. This is a test-only limitation; production deployments
# with separate processes per worker work correctly with TCP.
import
logging
import
logging
import
os
import
os
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
...
@@ -155,6 +161,8 @@ def _build_mocker_command(
...
@@ -155,6 +161,8 @@ def _build_mocker_command(
command
.
extend
([
"--data-parallel-size"
,
str
(
mocker_args
[
"dp_size"
])])
command
.
extend
([
"--data-parallel-size"
,
str
(
mocker_args
[
"dp_size"
])])
if
mocker_args
.
get
(
"enable_local_indexer"
):
if
mocker_args
.
get
(
"enable_local_indexer"
):
command
.
append
(
"--enable-local-indexer"
)
command
.
append
(
"--enable-local-indexer"
)
if
"bootstrap_ports"
in
mocker_args
:
command
.
extend
([
"--bootstrap-ports"
,
mocker_args
[
"bootstrap_ports"
]])
return
command
return
command
...
@@ -233,6 +241,7 @@ class DisaggMockerProcess:
...
@@ -233,6 +241,7 @@ class DisaggMockerProcess:
num_mockers
:
int
=
1
,
num_mockers
:
int
=
1
,
store_backend
:
str
=
"etcd"
,
store_backend
:
str
=
"etcd"
,
request_plane
:
str
=
"nats"
,
request_plane
:
str
=
"nats"
,
enable_bootstrap
:
bool
=
False
,
):
):
if
worker_type
not
in
(
"prefill"
,
"decode"
):
if
worker_type
not
in
(
"prefill"
,
"decode"
):
raise
ValueError
(
raise
ValueError
(
...
@@ -242,6 +251,7 @@ class DisaggMockerProcess:
...
@@ -242,6 +251,7 @@ class DisaggMockerProcess:
self
.
namespace
=
namespace
self
.
namespace
=
namespace
self
.
worker_type
=
worker_type
self
.
worker_type
=
worker_type
self
.
num_workers
=
num_mockers
self
.
num_workers
=
num_mockers
self
.
_bootstrap_ports
:
list
[
int
]
=
[]
# Set component name and endpoint based on worker type
# Set component name and endpoint based on worker type
if
worker_type
==
"prefill"
:
if
worker_type
==
"prefill"
:
...
@@ -251,7 +261,17 @@ class DisaggMockerProcess:
...
@@ -251,7 +261,17 @@ class DisaggMockerProcess:
self
.
component_name
=
"backend"
self
.
component_name
=
"backend"
self
.
endpoint
=
f
"dyn://
{
self
.
namespace
}
.backend.generate"
self
.
endpoint
=
f
"dyn://
{
self
.
namespace
}
.backend.generate"
mocker_args
=
mocker_args
or
{}
mocker_args
=
(
mocker_args
or
{}).
copy
()
# Allocate bootstrap ports for prefill workers if enabled (one per worker)
if
enable_bootstrap
and
worker_type
==
"prefill"
:
self
.
_bootstrap_ports
=
allocate_ports
(
num_mockers
,
BASE_PORT
)
mocker_args
[
"bootstrap_ports"
]
=
","
.
join
(
str
(
p
)
for
p
in
self
.
_bootstrap_ports
)
logger
.
info
(
f
"Allocated bootstrap ports
{
self
.
_bootstrap_ports
}
for
{
num_mockers
}
prefill workers"
)
command
=
_build_mocker_command
(
command
=
_build_mocker_command
(
endpoint
=
self
.
endpoint
,
endpoint
=
self
.
endpoint
,
...
@@ -279,6 +299,11 @@ class DisaggMockerProcess:
...
@@ -279,6 +299,11 @@ class DisaggMockerProcess:
f
"endpoint:
{
self
.
endpoint
}
"
f
"endpoint:
{
self
.
endpoint
}
"
)
)
@
property
def
bootstrap_ports
(
self
)
->
list
[
int
]:
"""Return the allocated bootstrap ports, if any."""
return
self
.
_bootstrap_ports
def
__enter__
(
self
):
def
__enter__
(
self
):
logger
.
info
(
logger
.
info
(
f
"Starting
{
self
.
worker_type
}
mocker process with
{
self
.
num_workers
}
worker(s)"
f
"Starting
{
self
.
worker_type
}
mocker process with
{
self
.
num_workers
}
worker(s)"
...
@@ -289,6 +314,11 @@ class DisaggMockerProcess:
...
@@ -289,6 +314,11 @@ class DisaggMockerProcess:
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
logger
.
info
(
f
"Stopping
{
self
.
worker_type
}
mocker process"
)
logger
.
info
(
f
"Stopping
{
self
.
worker_type
}
mocker process"
)
self
.
_process
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
self
.
_process
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
# Deallocate bootstrap ports if we allocated any
if
self
.
_bootstrap_ports
:
deallocate_ports
(
self
.
_bootstrap_ports
)
logger
.
info
(
f
"Deallocated bootstrap ports
{
self
.
_bootstrap_ports
}
"
)
self
.
_bootstrap_ports
=
[]
@
pytest
.
mark
.
timeout
(
42
)
# ~3x average (~13.80s), rounded up
@
pytest
.
mark
.
timeout
(
42
)
# ~3x average (~13.80s), rounded up
...
@@ -487,9 +517,9 @@ def test_kv_push_router_bindings(
...
@@ -487,9 +517,9 @@ def test_kv_push_router_bindings(
],
],
ids
=
[
ids
=
[
"jetstream"
,
"jetstream"
,
"nats"
,
"nats
_core
"
,
"file"
,
"file"
,
],
# "nats_core" commented out to match commented test case
],
)
)
@
pytest
.
mark
.
timeout
(
90
)
# TODO: figure out a timeout
@
pytest
.
mark
.
timeout
(
90
)
# TODO: figure out a timeout
def
test_indexers_sync
(
def
test_indexers_sync
(
...
@@ -677,37 +707,43 @@ def test_router_decisions(
...
@@ -677,37 +707,43 @@ def test_router_decisions(
mockers
.
__exit__
(
None
,
None
,
None
)
mockers
.
__exit__
(
None
,
None
,
None
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"registration_order"
,
[
"prefill_first"
,
"decode_first"
])
@
pytest
.
mark
.
parametrize
(
"registration_order"
,
[
"prefill_first"
,
"decode_first"
])
@
pytest
.
mark
.
parametrize
(
"enable_disagg_bootstrap"
,
[
False
,
True
],
ids
=
[
"no_bootstrap"
,
"with_bootstrap"
]
)
@
pytest
.
mark
.
timeout
(
59
)
# ~3x average (~19.51s), rounded up
@
pytest
.
mark
.
timeout
(
59
)
# ~3x average (~19.51s), rounded up
def
test_router_decisions_disagg
(
def
test_router_decisions_disagg
(
request
,
request
,
runtime_services_dynamic_ports
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
predownload_tokenizers
,
registration_order
,
registration_order
,
request_plane
,
enable_disagg_bootstrap
,
):
):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup.
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup.
Tests that progressive requests with overlapping prefixes are routed to the
Tests that progressive requests with overlapping prefixes are routed to the
same prefill worker due to KV cache reuse.
same prefill worker due to KV cache reuse.
Parameterized to test
both registration orders
:
Parameterized to test:
-
p
re
fill_first: prefill w
or
k
er
s register before
decode
worke
rs
- re
gistration_
or
d
er
: prefill_first vs
decode
_fi
rs
t
-
decode_first: decode workers register before prefill worker
s
-
enable_disagg_bootstrap: without vs with bootstrap rendezvou
s
"""
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
# runtime_services_dynamic_ports handles NATS and etcd startup
logger
.
info
(
logger
.
info
(
f
"Starting disaggregated router prefix reuse test "
f
"Starting disaggregated router prefix reuse test "
f
"(registration_order=
{
registration_order
}
)"
f
"(registration_order=
{
registration_order
}
, bootstrap=
{
enable_disagg_bootstrap
}
)"
)
)
# Generate shared namespace for prefill and decode workers
# Generate shared namespace for prefill and decode workers
namespace_suffix
=
generate_random_suffix
()
namespace_suffix
=
generate_random_suffix
()
shared_namespace
=
f
"test-namespace-
{
namespace_suffix
}
"
shared_namespace
=
f
"test-namespace-
{
namespace_suffix
}
"
# Create mocker args
# Create mocker args - use JetStream for KV events (more reliable than NATS Core)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"enable_local_indexer"
:
False
,
}
prefill_workers
=
None
prefill_workers
=
None
decode_workers
=
None
decode_workers
=
None
...
@@ -722,7 +758,8 @@ def test_router_decisions_disagg(
...
@@ -722,7 +758,8 @@ def test_router_decisions_disagg(
worker_type
=
"prefill"
,
worker_type
=
"prefill"
,
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
4
,
num_mockers
=
4
,
request_plane
=
request_plane
,
request_plane
=
"nats"
,
enable_bootstrap
=
enable_disagg_bootstrap
,
)
)
prefill_workers
.
__enter__
()
prefill_workers
.
__enter__
()
logger
.
info
(
f
"Prefill workers using endpoint:
{
prefill_workers
.
endpoint
}
"
)
logger
.
info
(
f
"Prefill workers using endpoint:
{
prefill_workers
.
endpoint
}
"
)
...
@@ -735,7 +772,7 @@ def test_router_decisions_disagg(
...
@@ -735,7 +772,7 @@ def test_router_decisions_disagg(
worker_type
=
"decode"
,
worker_type
=
"decode"
,
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
4
,
num_mockers
=
4
,
request_plane
=
request_plane
,
request_plane
=
"nats"
,
)
)
decode_workers
.
__enter__
()
decode_workers
.
__enter__
()
logger
.
info
(
f
"Decode workers using endpoint:
{
decode_workers
.
endpoint
}
"
)
logger
.
info
(
f
"Decode workers using endpoint:
{
decode_workers
.
endpoint
}
"
)
...
@@ -748,7 +785,7 @@ def test_router_decisions_disagg(
...
@@ -748,7 +785,7 @@ def test_router_decisions_disagg(
worker_type
=
"decode"
,
worker_type
=
"decode"
,
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
4
,
num_mockers
=
4
,
request_plane
=
request_plane
,
request_plane
=
"nats"
,
)
)
decode_workers
.
__enter__
()
decode_workers
.
__enter__
()
logger
.
info
(
f
"Decode workers using endpoint:
{
decode_workers
.
endpoint
}
"
)
logger
.
info
(
f
"Decode workers using endpoint:
{
decode_workers
.
endpoint
}
"
)
...
@@ -761,7 +798,8 @@ def test_router_decisions_disagg(
...
@@ -761,7 +798,8 @@ def test_router_decisions_disagg(
worker_type
=
"prefill"
,
worker_type
=
"prefill"
,
mocker_args
=
mocker_args
,
mocker_args
=
mocker_args
,
num_mockers
=
4
,
num_mockers
=
4
,
request_plane
=
request_plane
,
request_plane
=
"nats"
,
enable_bootstrap
=
enable_disagg_bootstrap
,
)
)
prefill_workers
.
__enter__
()
prefill_workers
.
__enter__
()
logger
.
info
(
f
"Prefill workers using endpoint:
{
prefill_workers
.
endpoint
}
"
)
logger
.
info
(
f
"Prefill workers using endpoint:
{
prefill_workers
.
endpoint
}
"
)
...
@@ -779,7 +817,7 @@ def test_router_decisions_disagg(
...
@@ -779,7 +817,7 @@ def test_router_decisions_disagg(
request
=
request
,
request
=
request
,
frontend_port
=
frontend_port
,
frontend_port
=
frontend_port
,
test_payload
=
TEST_PAYLOAD
,
test_payload
=
TEST_PAYLOAD
,
request_plane
=
request_plane
,
request_plane
=
"nats"
,
)
)
finally
:
finally
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment