Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
584020f4
Unverified
Commit
584020f4
authored
Jan 29, 2026
by
Yan Ru Pei
Committed by
GitHub
Jan 29, 2026
Browse files
chore: consistently used the arc shared workers configs + CI flake fixes (#5707)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
1174c819
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
729 additions
and
404 deletions
+729
-404
lib/llm/src/discovery.rs
lib/llm/src/discovery.rs
+4
-1
lib/llm/src/discovery/model_manager.rs
lib/llm/src/discovery/model_manager.rs
+11
-140
lib/llm/src/discovery/runtime_configs.rs
lib/llm/src/discovery/runtime_configs.rs
+200
-0
lib/llm/src/hub.rs
lib/llm/src/hub.rs
+64
-0
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+18
-35
lib/llm/src/kv_router/publisher.rs
lib/llm/src/kv_router/publisher.rs
+6
-9
lib/llm/src/kv_router/scheduler.rs
lib/llm/src/kv_router/scheduler.rs
+14
-8
lib/llm/src/kv_router/subscriber.rs
lib/llm/src/kv_router/subscriber.rs
+30
-39
lib/llm/src/kv_router/worker_query.rs
lib/llm/src/kv_router/worker_query.rs
+17
-15
lib/runtime/src/config/environment_names.rs
lib/runtime/src/config/environment_names.rs
+5
-0
lib/runtime/src/transports/nats.rs
lib/runtime/src/transports/nats.rs
+0
-6
tests/conftest.py
tests/conftest.py
+199
-103
tests/router/common.py
tests/router/common.py
+16
-8
tests/router/test_router_e2e_with_mockers.py
tests/router/test_router_e2e_with_mockers.py
+77
-20
tests/router/test_router_e2e_with_sglang.py
tests/router/test_router_e2e_with_sglang.py
+17
-5
tests/router/test_router_e2e_with_trtllm.py
tests/router/test_router_e2e_with_trtllm.py
+16
-4
tests/router/test_router_e2e_with_vllm.py
tests/router/test_router_e2e_with_vllm.py
+16
-5
tests/utils/managed_process.py
tests/utils/managed_process.py
+19
-6
No files found.
lib/llm/src/discovery.rs
View file @
584020f4
...
...
@@ -2,7 +2,10 @@
// SPDX-License-Identifier: Apache-2.0
mod
model_manager
;
pub
use
model_manager
::{
ModelManager
,
ModelManagerError
,
RuntimeConfigsWithNotify
};
pub
use
model_manager
::{
ModelManager
,
ModelManagerError
};
pub
(
crate
)
mod
runtime_configs
;
pub
use
runtime_configs
::{
RuntimeConfigs
,
RuntimeConfigsSubscriber
};
mod
watcher
;
pub
use
watcher
::{
ModelUpdate
,
ModelWatcher
};
...
...
lib/llm/src/discovery/model_manager.rs
View file @
584020f4
...
...
@@ -8,13 +8,14 @@ use std::{
use
dashmap
::{
DashMap
,
mapref
::
entry
::
Entry
};
use
parking_lot
::{
Mutex
,
RwLock
};
use
tokio
::
sync
::
{
Notify
,
oneshot
}
;
use
tokio
::
sync
::
oneshot
;
use
crate
::
discovery
::
KvWorkerMonitor
;
use
crate
::
discovery
::
runtime_configs
::
RuntimeConfigs
;
use
dynamo_runtime
::{
component
::{
Client
,
Endpoint
,
build_transport_type
},
discovery
::
{
DiscoveryQuery
,
DiscoverySpec
,
watch_and_extract_field
}
,
discovery
::
DiscoverySpec
,
prelude
::
DistributedRuntimeProvider
,
protocols
::
EndpointId
,
};
...
...
@@ -24,7 +25,7 @@ use crate::{
KvRouter
,
KvRouterConfig
,
protocols
::
WorkerId
,
router_endpoint_id
,
scheduler
::
DefaultWorkerSelector
,
},
local_model
::
runtime_config
::
{
DisaggregatedEndpoint
,
ModelRuntimeConfig
},
local_model
::
runtime_config
::
DisaggregatedEndpoint
,
model_card
::
ModelDeploymentCard
,
model_type
::
ModelType
,
types
::{
...
...
@@ -81,14 +82,8 @@ pub struct ModelManager {
/// Runtime configs per endpoint using DashMap for lock-free access.
/// Outer DashMap: keyed by EndpointId
/// Inner RuntimeConfigsWithNotify: shared with KvScheduler
runtime_configs
:
DashMap
<
EndpointId
,
Arc
<
RuntimeConfigsWithNotify
>>
,
}
/// Runtime configs for an endpoint with a notify for change notifications.
pub
struct
RuntimeConfigsWithNotify
{
pub
configs
:
DashMap
<
WorkerId
,
Option
<
ModelRuntimeConfig
>>
,
pub
notify
:
Notify
,
/// Inner RuntimeConfigs: shared with KvScheduler
runtime_configs
:
DashMap
<
EndpointId
,
Arc
<
RuntimeConfigs
>>
,
}
impl
Default
for
ModelManager
{
...
...
@@ -621,12 +616,12 @@ impl ModelManager {
}
/// Get or create a runtime config watcher for an endpoint.
/// Spawns a background task to watch
DiscoveryQuery::EndpointModel
s.
/// Returns a shared RuntimeConfigs
WithNotify
that KvScheduler can use directly.
/// Spawns a background task to watch
for worker config change
s.
/// Returns a shared RuntimeConfigs that KvScheduler can use directly.
pub
async
fn
get_or_create_runtime_config_watcher
(
&
self
,
endpoint
:
&
Endpoint
,
)
->
anyhow
::
Result
<
Arc
<
RuntimeConfigs
WithNotify
>>
{
)
->
anyhow
::
Result
<
Arc
<
RuntimeConfigs
>>
{
let
endpoint_id
=
endpoint
.id
();
// Fast path: return existing if present
...
...
@@ -635,10 +630,7 @@ impl ModelManager {
}
// Atomic get-or-insert to avoid TOCTOU race
let
inner
=
Arc
::
new
(
RuntimeConfigsWithNotify
{
configs
:
DashMap
::
new
(),
notify
:
Notify
::
new
(),
});
let
inner
=
Arc
::
new
(
RuntimeConfigs
::
new
());
let
(
result
,
is_new
)
=
match
self
.runtime_configs
.entry
(
endpoint_id
)
{
Entry
::
Occupied
(
e
)
=>
(
e
.get
()
.clone
(),
false
),
Entry
::
Vacant
(
e
)
=>
{
...
...
@@ -649,8 +641,7 @@ impl ModelManager {
// Only spawn watcher if we were the one who inserted
if
is_new
{
self
.spawn_runtime_config_watcher
(
endpoint
,
result
.clone
())
.await
?
;
result
.start_watcher
(
endpoint
)
.await
?
;
}
Ok
(
result
)
...
...
@@ -668,126 +659,6 @@ impl ModelManager {
config_ref
.as_ref
()
?
.disaggregated_endpoint
.clone
()
}
/// Spawn background task to watch runtime configs via discovery.
/// Blocks until at least one worker with a runtime config is available.
async
fn
spawn_runtime_config_watcher
(
&
self
,
endpoint
:
&
Endpoint
,
inner
:
Arc
<
RuntimeConfigsWithNotify
>
,
)
->
anyhow
::
Result
<
()
>
{
let
component
=
endpoint
.component
();
let
cancellation_token
=
component
.drt
()
.primary_token
();
// Set up discovery watch for EndpointModels
let
discovery
=
component
.drt
()
.discovery
();
let
endpoint_id
=
endpoint
.id
();
let
discovery_key
=
DiscoveryQuery
::
EndpointModels
{
namespace
:
endpoint_id
.namespace
.clone
(),
component
:
endpoint_id
.component
.clone
(),
endpoint
:
endpoint_id
.name
.clone
(),
};
let
discovery_stream
=
discovery
.list_and_watch
(
discovery_key
.clone
(),
Some
(
cancellation_token
.clone
()))
.await
?
;
// Extract runtime_config from ModelDeploymentCard
let
mut
runtime_configs_rx
=
watch_and_extract_field
(
discovery_stream
,
|
card
:
ModelDeploymentCard
|
{
card
.runtime_config
});
// Also watch instance IDs
let
client
=
endpoint
.client
()
.await
?
;
let
mut
instance_ids_rx
=
client
.instance_avail_watcher
();
// Wait for at least one worker with runtime config before proceeding.
// This ensures the DashMap is populated before KvScheduler starts.
tracing
::
info!
(
"ModelManager: Waiting for at least one worker with runtime config..."
);
runtime_configs_rx
.changed
()
.await
.map_err
(|
_
|
anyhow
::
anyhow!
(
"runtime configs watch sender shutdown while waiting"
))
?
;
// Populate initial state
{
let
instance_ids
=
instance_ids_rx
.borrow
();
let
configs
=
runtime_configs_rx
.borrow
();
for
worker_id
in
instance_ids
.iter
()
{
let
config
=
configs
.get
(
worker_id
)
.cloned
();
inner
.configs
.insert
(
*
worker_id
,
config
);
}
tracing
::
info!
(
"ModelManager: Found {} workers, proceeding"
,
inner
.configs
.len
()
);
}
// Spawn background task to update configs for future changes
let
cancel_token
=
cancellation_token
.clone
();
tokio
::
spawn
(
async
move
{
tracing
::
trace!
(
"ModelManager runtime config watcher started"
);
loop
{
// Wait for either instances or configs to change
tokio
::
select!
{
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
trace!
(
"ModelManager runtime config watcher shutting down"
);
break
;
}
result
=
instance_ids_rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
warn!
(
"instance IDs watch sender shutdown in ModelManager"
);
break
;
}
}
result
=
runtime_configs_rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
warn!
(
"runtime configs watch sender shutdown in ModelManager"
);
break
;
}
}
}
// Get the latest values from both channels
let
new_instance_ids
=
instance_ids_rx
.borrow_and_update
()
.clone
();
let
new_configs
=
runtime_configs_rx
.borrow_and_update
()
.clone
();
// Update the DashMap
// First, remove workers that no longer exist
let
current_workers
:
HashSet
<
WorkerId
>
=
inner
.configs
.iter
()
.map
(|
r
|
*
r
.key
())
.collect
();
let
new_workers
:
HashSet
<
WorkerId
>
=
new_instance_ids
.iter
()
.copied
()
.collect
();
for
removed_worker
in
current_workers
.difference
(
&
new_workers
)
{
inner
.configs
.remove
(
removed_worker
);
}
// Then, add/update workers
for
worker_id
in
&
new_instance_ids
{
let
config
=
new_configs
.get
(
worker_id
)
.cloned
();
if
config
.is_some
()
{
let
prev_config
=
inner
.configs
.get
(
worker_id
);
if
prev_config
.as_ref
()
.map
(|
r
|
r
.value
())
!=
Some
(
&
config
)
{
tracing
::
info!
(
"ModelManager: Runtime config found for worker_id: {worker_id}"
);
}
}
inner
.configs
.insert
(
*
worker_id
,
config
);
}
// Notify waiters that configs have changed
inner
.notify
.notify_waiters
();
tracing
::
trace!
(
"ModelManager: Updated runtime_configs with {} workers"
,
inner
.configs
.len
()
);
}
tracing
::
trace!
(
"ModelManager runtime config watcher shutting down"
);
});
Ok
(())
}
/// Lists all models that have worker monitors (and thus busy thresholds) configured.
///
/// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples.
...
...
lib/llm/src/discovery/runtime_configs.rs
0 → 100644
View file @
584020f4
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
collections
::{
HashMap
,
HashSet
};
use
std
::
sync
::
Arc
;
use
dashmap
::
DashMap
;
use
tokio
::
sync
::
watch
;
use
dynamo_runtime
::
component
::
Endpoint
;
use
dynamo_runtime
::
discovery
::{
DiscoveryQuery
,
watch_and_extract_field
};
use
dynamo_runtime
::
prelude
::
DistributedRuntimeProvider
;
use
crate
::
kv_router
::
protocols
::
WorkerId
;
use
crate
::
local_model
::
runtime_config
::
ModelRuntimeConfig
;
use
crate
::
model_card
::
ModelDeploymentCard
;
/// Runtime configs for an endpoint with watch-based change notifications.
/// Call `subscribe()` to get a subscriber with its own watch receiver.
pub
struct
RuntimeConfigs
{
pub
configs
:
Arc
<
DashMap
<
WorkerId
,
Option
<
ModelRuntimeConfig
>>>
,
change_tx
:
watch
::
Sender
<
u64
>
,
}
impl
RuntimeConfigs
{
pub
(
crate
)
fn
new
()
->
Self
{
let
(
change_tx
,
_
)
=
watch
::
channel
(
0u64
);
Self
{
configs
:
Arc
::
new
(
DashMap
::
new
()),
change_tx
,
}
}
/// Create a subscriber that can wait for config changes.
/// Each subscriber has its own watch receiver, so notifications are not lost.
pub
fn
subscribe
(
&
self
)
->
RuntimeConfigsSubscriber
{
RuntimeConfigsSubscriber
{
configs
:
self
.configs
.clone
(),
change_rx
:
self
.change_tx
.subscribe
(),
}
}
/// Notify all subscribers of a change (internal use only).
fn
notify_change
(
&
self
)
{
// Increment counter to notify subscribers
self
.change_tx
.send_modify
(|
v
|
*
v
=
v
.wrapping_add
(
1
));
}
/// Returns the number of workers in the configs.
pub
fn
num_workers
(
&
self
)
->
usize
{
self
.configs
.len
()
}
/// Update configs with new worker instances and their configs.
/// Notifies subscribers if a config with Some value is added or a worker is removed.
pub
(
crate
)
fn
update
(
&
self
,
new_instance_ids
:
&
[
WorkerId
],
new_configs
:
&
HashMap
<
WorkerId
,
ModelRuntimeConfig
>
,
)
{
// First, remove workers that no longer exist
let
current_workers
:
HashSet
<
WorkerId
>
=
self
.configs
.iter
()
.map
(|
r
|
*
r
.key
())
.collect
();
let
new_workers
:
HashSet
<
WorkerId
>
=
new_instance_ids
.iter
()
.copied
()
.collect
();
let
mut
worker_removed
=
false
;
for
removed_worker
in
current_workers
.difference
(
&
new_workers
)
{
self
.configs
.remove
(
removed_worker
);
worker_removed
=
true
;
}
// Then, add/update workers
// Track if any config became Some (for notify)
let
mut
config_added
=
false
;
for
worker_id
in
new_instance_ids
{
let
config
=
new_configs
.get
(
worker_id
)
.cloned
();
if
config
.is_some
()
{
let
prev_config
=
self
.configs
.get
(
worker_id
);
let
was_none
=
prev_config
.as_ref
()
.map
(|
r
|
r
.value
()
.is_none
())
.unwrap_or
(
true
);
if
was_none
{
tracing
::
info!
(
"RuntimeConfigs: config found for worker_id: {worker_id}"
);
config_added
=
true
;
}
}
self
.configs
.insert
(
*
worker_id
,
config
);
}
// Notify when a config with Some value is added OR a worker is removed
if
config_added
||
worker_removed
{
self
.notify_change
();
}
}
/// Spawn background task to watch runtime configs via discovery.
/// Does not block - consumers should use `subscribe().wait_for_some()` if they need workers.
pub
(
crate
)
async
fn
start_watcher
(
self
:
&
Arc
<
Self
>
,
endpoint
:
&
Endpoint
)
->
anyhow
::
Result
<
()
>
{
let
component
=
endpoint
.component
();
let
cancellation_token
=
component
.drt
()
.primary_token
();
// Set up discovery watch for EndpointModels
let
discovery
=
component
.drt
()
.discovery
();
let
endpoint_id
=
endpoint
.id
();
let
discovery_key
=
DiscoveryQuery
::
EndpointModels
{
namespace
:
endpoint_id
.namespace
.clone
(),
component
:
endpoint_id
.component
.clone
(),
endpoint
:
endpoint_id
.name
.clone
(),
};
let
discovery_stream
=
discovery
.list_and_watch
(
discovery_key
.clone
(),
Some
(
cancellation_token
.clone
()))
.await
?
;
// Extract runtime_config from ModelDeploymentCard
let
mut
runtime_configs_rx
=
watch_and_extract_field
(
discovery_stream
,
|
card
:
ModelDeploymentCard
|
{
card
.runtime_config
});
// Also watch instance IDs
let
client
=
endpoint
.client
()
.await
?
;
let
mut
instance_ids_rx
=
client
.instance_avail_watcher
();
// Spawn background task to watch for config changes
// Note: We don't block here - consumers should wait on notify for configs they need
let
inner
=
self
.clone
();
let
cancel_token
=
cancellation_token
.clone
();
tokio
::
spawn
(
async
move
{
tracing
::
trace!
(
"RuntimeConfigs watcher started"
);
loop
{
// Wait for either instances or configs to change
tokio
::
select!
{
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
trace!
(
"RuntimeConfigs watcher shutting down"
);
break
;
}
result
=
instance_ids_rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
warn!
(
"instance IDs watch sender shutdown"
);
break
;
}
}
result
=
runtime_configs_rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
warn!
(
"runtime configs watch sender shutdown"
);
break
;
}
}
}
// Get the latest values from both channels
let
new_instance_ids
=
instance_ids_rx
.borrow_and_update
()
.clone
();
let
new_configs
=
runtime_configs_rx
.borrow_and_update
()
.clone
();
inner
.update
(
&
new_instance_ids
,
&
new_configs
);
tracing
::
trace!
(
"RuntimeConfigs: Updated with {} workers"
,
inner
.configs
.len
()
);
}
tracing
::
trace!
(
"RuntimeConfigs watcher stopped"
);
});
Ok
(())
}
}
/// Handle for observing runtime config changes.
///
/// Holds a shared reference to the config table together with a watch receiver
/// of its own, so one subscriber consuming a notification does not hide it
/// from another.
pub struct RuntimeConfigsSubscriber {
    /// Shared worker id -> runtime config table.
    pub configs: Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>,
    /// Receiver for the table's change counter.
    pub change_rx: watch::Receiver<u64>,
}

impl RuntimeConfigsSubscriber {
    /// Block until at least one worker has a `Some` config and return the ids
    /// of all such workers.
    ///
    /// The table is inspected before every wait, so a config that arrived
    /// before this call is picked up immediately. If the sender side is
    /// dropped while waiting (shutdown), an empty vector is returned instead
    /// of waiting forever.
    pub async fn wait_for_some(&mut self) -> Vec<WorkerId> {
        loop {
            // The iterator (and its map guards) is gone by the end of this
            // statement, i.e. before the await below.
            let with_config: Vec<WorkerId> = self
                .configs
                .iter()
                .filter_map(|entry| entry.value().is_some().then_some(*entry.key()))
                .collect();
            if !with_config.is_empty() {
                return with_config;
            }

            if self.change_rx.changed().await.is_err() {
                tracing::warn!("RuntimeConfigsSubscriber: sender dropped during wait_for_some");
                return Vec::new();
            }
        }
    }
}
lib/llm/src/hub.rs
View file @
584020f4
...
...
@@ -4,6 +4,7 @@
use
std
::
env
;
use
std
::
path
::{
Path
,
PathBuf
};
use
hf_hub
::
Cache
;
use
modelexpress_client
::{
Client
as
MxClient
,
ClientConfig
as
MxClientConfig
,
ModelProvider
as
MxModelProvider
,
};
...
...
@@ -11,14 +12,77 @@ use modelexpress_common::download as mx;
use
dynamo_runtime
::
config
::
environment_names
::
model
as
env_model
;
/// Look up `model_name` in the local HuggingFace-style hub cache.
///
/// Returns the directory holding the cached `config.json` when the cache
/// contains everything needed, `None` otherwise. "Everything needed" means:
/// - `config.json`, and
/// - at least one of `tokenizer.json` / `tokenizer_config.json`, and
/// - unless `ignore_weights` is true, at least one of the known weight files
///   or weight index files.
fn get_cached_model_path(model_name: &str, ignore_weights: bool) -> Option<PathBuf> {
    // Any one of these is enough to count as "tokenizer present".
    const TOKENIZER_FILES: [&str; 2] = ["tokenizer.json", "tokenizer_config.json"];
    // Any one of these is enough to count as "weights present".
    const WEIGHT_FILES: [&str; 4] = [
        "model.safetensors",
        "pytorch_model.bin",
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",
    ];

    let hub_cache = Cache::new(get_model_express_cache_dir());
    let repo = hub_cache.model(model_name.to_string());

    // The config file is mandatory and also anchors the returned directory.
    let config_path = repo.get("config.json")?;

    if !TOKENIZER_FILES.iter().any(|name| repo.get(name).is_some()) {
        return None;
    }

    // Weight files only matter for full (non tokenizer-only) downloads.
    if !ignore_weights && !WEIGHT_FILES.iter().any(|name| repo.get(name).is_some()) {
        return None;
    }

    // The directory containing config.json is what callers load from.
    let snapshot_path = config_path.parent()?.to_path_buf();
    tracing::info!("Found cached model '{model_name}' at {snapshot_path:?}, skipping download");
    Some(snapshot_path)
}
/// Report whether offline mode is requested through the `HF_HUB_OFFLINE`
/// environment variable.
///
/// Only the value `1` or any casing of `true` enables offline mode. An unset
/// or unreadable variable counts as "not offline".
fn is_offline_mode() -> bool {
    match env::var(env_model::huggingface::HF_HUB_OFFLINE) {
        Ok(value) => value == "1" || value.to_lowercase() == "true",
        Err(_) => false,
    }
}
/// Download a model using the ModelExpress client. The client first requests the model
/// from the server and falls back to direct download in case of server failure.
/// If ignore_weights is true, model weight files will be skipped.
/// Returns the path to the model files.
///
/// If HF_HUB_OFFLINE=1 is set and the model is already cached, returns the cached
/// path without making any API calls to HuggingFace.
pub
async
fn
from_hf
(
name
:
impl
AsRef
<
Path
>
,
ignore_weights
:
bool
)
->
anyhow
::
Result
<
PathBuf
>
{
let
name
=
name
.as_ref
();
let
model_name
=
name
.display
()
.to_string
();
// In offline mode, check cache first and return immediately if found
if
is_offline_mode
()
{
if
let
Some
(
cached_path
)
=
get_cached_model_path
(
&
model_name
,
ignore_weights
)
{
tracing
::
info!
(
"Offline mode: using cached model '{model_name}' without API validation"
);
return
Ok
(
cached_path
);
}
tracing
::
warn!
(
"Offline mode enabled but model '{model_name}' not found in cache, attempting download anyway"
);
}
let
mut
config
:
MxClientConfig
=
MxClientConfig
::
default
();
if
let
Ok
(
endpoint
)
=
env
::
var
(
env_model
::
model_express
::
MODEL_EXPRESS_URL
)
{
config
=
config
.with_endpoint
(
endpoint
);
...
...
lib/llm/src/kv_router.rs
View file @
584020f4
...
...
@@ -9,7 +9,7 @@ use anyhow::Result;
use
derive_builder
::
Builder
;
use
dynamo_runtime
::{
component
::{
Client
,
Endpoint
},
discovery
::{
DiscoveryQuery
,
EventTransportKind
,
watch_and_extract_field
},
discovery
::{
DiscoveryQuery
,
EventTransportKind
},
pipeline
::{
AsyncEngine
,
AsyncEngineContextProvider
,
Error
,
ManyOut
,
PushRouter
,
ResponseStream
,
SingleIn
,
async_trait
,
...
...
@@ -41,7 +41,7 @@ pub use prefill_router::PrefillRouter;
use
worker_query
::
WorkerQueryClient
;
use
crate
::{
discovery
::
RuntimeConfigs
WithNotify
,
discovery
::
RuntimeConfigs
,
kv_router
::{
approx
::
PruneConfig
,
indexer
::{
KvIndexer
,
KvIndexerInterface
,
KvRouterError
},
...
...
@@ -55,7 +55,6 @@ use crate::{
subscriber
::{
start_kv_router_background
,
start_kv_router_background_event_plane
},
},
local_model
::
runtime_config
::
ModelRuntimeConfig
,
model_card
::
ModelDeploymentCard
,
preprocessor
::
PreprocessedRequest
,
protocols
::
common
::
llm_backend
::
LLMEngineOutput
,
protocols
::
common
::
timing
::
RequestPhase
,
...
...
@@ -332,7 +331,7 @@ impl KvRouter {
pub
async
fn
new
(
endpoint
:
Endpoint
,
client
:
Client
,
workers_with_configs
:
Arc
<
RuntimeConfigs
WithNotify
>
,
workers_with_configs
:
Arc
<
RuntimeConfigs
>
,
block_size
:
u32
,
selector
:
Option
<
Box
<
dyn
WorkerSelector
+
Send
+
Sync
>>
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
...
...
@@ -342,23 +341,6 @@ impl KvRouter {
let
component
=
endpoint
.component
();
let
cancellation_token
=
component
.drt
()
.primary_token
();
// Watch for runtime config updates via discovery interface
// (still needed for WorkerQueryClient and background tasks)
let
discovery
=
component
.drt
()
.discovery
();
let
endpoint_id
=
endpoint
.id
();
let
discovery_key
=
DiscoveryQuery
::
EndpointModels
{
namespace
:
endpoint_id
.namespace
.clone
(),
component
:
endpoint_id
.component
.clone
(),
endpoint
:
endpoint_id
.name
.clone
(),
};
let
discovery_stream
=
discovery
.list_and_watch
(
discovery_key
.clone
(),
Some
(
cancellation_token
.clone
()))
.await
?
;
let
runtime_configs_rx
=
watch_and_extract_field
(
discovery_stream
,
|
card
:
ModelDeploymentCard
|
{
card
.runtime_config
});
let
indexer
=
if
kv_router_config
.overlap_score_weight
==
0.0
{
// When overlap_score_weight is zero, we don't need to track prefixes
Indexer
::
None
...
...
@@ -385,6 +367,9 @@ impl KvRouter {
))
};
// Wait for at least one worker with a known runtime config before starting scheduler
workers_with_configs
.subscribe
()
.wait_for_some
()
.await
;
let
scheduler
=
KvScheduler
::
start
(
component
.clone
(),
block_size
,
...
...
@@ -397,30 +382,27 @@ impl KvRouter {
// Initialize worker query client using namespace abstraction
// (created before background task so we can use it for startup recovery)
let
worker_query_client
=
worker_query
::
WorkerQueryClient
::
new
(
component
.clone
(),
runtime_configs_rx
.clone
());
// Uses a subscriber from workers_with_configs
let
worker_query_client
=
worker_query
::
WorkerQueryClient
::
new
(
component
.clone
(),
workers_with_configs
.subscribe
(),
);
tracing
::
info!
(
"Worker query client initialized"
);
// Start KV event subscriber background process (only when use_kv_events is enabled)
// model_manager.get_or_create_runtime_config_watcher() guarantees at least one worker exists.
if
kv_router_config
.use_kv_events
&&
let
Indexer
::
KvIndexer
(
ref
kv_indexer
)
=
indexer
{
// model_manager guarantees workers_with_configs is populated
// Wait for at least one worker before starting the subscriber
while
workers_with_configs
.configs
.is_empty
()
{
tracing
::
info!
(
"KV router waiting for at least one worker..."
);
workers_with_configs
.notify
.notified
()
.await
;
}
let
count
=
workers_with_configs
.configs
.len
();
let
all_local_indexer
=
workers_with_configs
.configs
.iter
()
.filter_map
(|
r
|
r
.value
()
.as_ref
()
.map
(|
c
|
c
.enable_local_indexer
))
.all
(|
b
|
b
);
tracing
::
info!
(
"Found {count} worker(s), starting KV event subscriber"
);
tracing
::
info!
(
"Found {} worker(s), starting KV event subscriber"
,
workers_with_configs
.num_workers
()
);
let
transport_kind
=
EventTransportKind
::
from_env_or_default
();
...
...
@@ -436,7 +418,8 @@ impl KvRouter {
}
}
else
{
tracing
::
info!
(
"All {count} workers have local_indexer enabled, using NATS Core subscription"
"All {} workers have local_indexer enabled, using NATS Core subscription"
,
workers_with_configs
.num_workers
()
);
}
...
...
@@ -447,7 +430,7 @@ impl KvRouter {
cancellation_token
.clone
(),
worker_query
::
WorkerQueryClient
::
new
(
component
.clone
(),
runtime_configs_rx
.clon
e
(),
workers_with_configs
.subscrib
e
(),
),
transport_kind
,
)
...
...
lib/llm/src/kv_router/publisher.rs
View file @
584020f4
...
...
@@ -921,7 +921,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
// -------------------------------------------------------------------------
/// Metrics data passed through the channel for NATS publishing
#[derive(Debug,
Clone,
Default)]
#[derive(Debug,
Clone,
Default
,
PartialEq
)]
struct
WorkerMetrics
{
dp_rank
:
DpRank
,
active_decode_blocks
:
u64
,
...
...
@@ -982,7 +982,7 @@ impl WorkerMetricsPublisher {
};
let
mut
rx
=
nats_rx
;
let
mut
last_
active_decode_blocks
:
Option
<
u64
>
=
Some
(
0
)
;
let
mut
last_
metrics
:
Option
<
WorkerMetrics
>
=
None
;
let
mut
pending_publish
:
Option
<
WorkerMetrics
>
=
None
;
let
mut
publish_timer
=
Box
::
pin
(
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_secs
(
0
)));
...
...
@@ -1001,16 +1001,13 @@ impl WorkerMetricsPublisher {
let
metrics
=
rx
.borrow_and_update
()
.clone
();
// Check if active_decode_blocks has changed
let
has_changed
=
match
last_active_decode_blocks
{
Some
(
last
)
=>
last
!=
metrics
.active_decode_blocks
,
None
=>
true
,
// First time, consider it changed
};
// Check if metrics have changed
let
has_changed
=
last_metrics
.as_ref
()
!=
Some
(
&
metrics
);
// If
load
metrics changed, schedule a publish
// If metrics changed, schedule a publish
if
has_changed
{
pending_publish
=
Some
(
metrics
.clone
());
last_
active_decode_blocks
=
Some
(
metrics
.active_decode_block
s
);
last_
metrics
=
Some
(
metric
s
);
// Start the 1ms timer
publish_timer
.as_mut
()
.reset
(
...
...
lib/llm/src/kv_router/scheduler.rs
View file @
584020f4
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
crate
::
discovery
::
RuntimeConfigs
WithNotify
;
use
crate
::
discovery
::
RuntimeConfigs
;
use
crate
::
local_model
::
runtime_config
::
ModelRuntimeConfig
;
use
anyhow
::
Result
;
use
dynamo_runtime
::
component
::
Component
;
...
...
@@ -98,7 +98,7 @@ impl KvScheduler {
pub
async
fn
start
(
component
:
Component
,
block_size
:
u32
,
workers_with_configs
:
Arc
<
RuntimeConfigs
WithNotify
>
,
workers_with_configs
:
Arc
<
RuntimeConfigs
>
,
selector
:
Option
<
Box
<
dyn
WorkerSelector
+
Send
+
Sync
>>
,
replica_sync
:
bool
,
router_id
:
u64
,
...
...
@@ -106,7 +106,7 @@ impl KvScheduler {
let
selector
=
selector
.unwrap_or
(
Box
::
new
(
DefaultWorkerSelector
::
default
()));
// Get initial workers from DashMap for slot initialization.
//
ModelManager guarantees
at least one worker is present
before KvRouter::new() is called
.
//
Caller must ensure
at least one worker is present
(via wait_for_some)
.
let
initial_workers
:
HashMap
<
WorkerId
,
Option
<
ModelRuntimeConfig
>>
=
workers_with_configs
.configs
.iter
()
...
...
@@ -126,9 +126,11 @@ impl KvScheduler {
);
// Spawn background task to sync slots with DashMap when notified of changes.
// ModelManager's watcher updates the DashMap and notifies; we wait on
notify
here.
// ModelManager's watcher updates the DashMap and notifies; we wait on
watch receiver
here.
let
slots_monitor
=
slots
.clone
();
let
workers_monitor
=
workers_with_configs
.clone
();
let
subscriber
=
workers_with_configs
.subscribe
();
let
configs_monitor
=
subscriber
.configs
;
let
mut
change_rx
=
subscriber
.change_rx
;
let
monitor_cancel_token
=
component
.drt
()
.child_token
();
tokio
::
spawn
(
async
move
{
tracing
::
trace!
(
"KvScheduler workers monitoring task started"
);
...
...
@@ -141,13 +143,17 @@ impl KvScheduler {
tracing
::
trace!
(
"KvScheduler workers monitoring task shutting down"
);
break
;
}
_
=
workers_monitor
.notify
.notified
()
=>
{}
result
=
change_rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
warn!
(
"KvScheduler: config watch sender dropped, shutting down"
);
break
;
}
}
}
// Get current workers from DashMap
let
current_workers
:
HashMap
<
WorkerId
,
Option
<
ModelRuntimeConfig
>>
=
workers_monitor
.configs
configs_monitor
.iter
()
.map
(|
r
|
(
*
r
.key
(),
r
.value
()
.clone
()))
.collect
();
...
...
lib/llm/src/kv_router/subscriber.rs
View file @
584020f4
...
...
@@ -55,9 +55,9 @@ const WORKER_QUERY_INITIAL_BACKOFF_MS: u64 = 200;
// Discovery Helpers
// ============================================================================
///
Wait for at least one worker instance to be discovered
.
///
Returns a peekable stream of discovery events for the generate endpo
in
t
.
async
fn
wait_for_worker_instance
(
///
Get the instance discovery stream for monitoring worker add/remove events
.
///
Waits for at least one instance to be discovered before return
in
g
.
async
fn
get_instance_discovery_stream
(
component
:
&
Component
,
cancellation_token
:
&
CancellationToken
,
)
->
Result
<
std
::
pin
::
Pin
<
Box
<
dyn
futures
::
Stream
<
Item
=
Result
<
DiscoveryEvent
>>
+
Send
>>>
{
...
...
@@ -524,7 +524,7 @@ pub async fn start_kv_router_background(
// Wait for at least one worker instance before proceeding
let
mut
instance_event_stream
=
wait_for_worker_instance
(
&
component
,
&
cancellation_token
)
.await
?
;
get_instance_discovery_stream
(
&
component
,
&
cancellation_token
)
.await
?
;
// Watch for router deletions to clean up orphaned consumers via discovery
let
generate_endpoint
=
component
.endpoint
(
"generate"
);
...
...
@@ -762,7 +762,7 @@ pub async fn start_kv_router_background_event_plane(
kv_events_tx
:
mpsc
::
Sender
<
RouterEvent
>
,
remove_worker_tx
:
mpsc
::
Sender
<
WorkerId
>
,
cancellation_token
:
CancellationToken
,
worker_query_client
:
WorkerQueryClient
,
mut
worker_query_client
:
WorkerQueryClient
,
transport_kind
:
EventTransportKind
,
)
->
Result
<
()
>
{
// Subscribe to KV events using the selected event plane transport
...
...
@@ -792,44 +792,35 @@ pub async fn start_kv_router_background_event_plane(
}
}
// Wait for at least one worker instance before proceeding
let
mut
instance_event_stream
=
wait_for_worker_instance
(
&
component
,
&
cancellation_token
)
.await
?
;
// Drain and process all existing workers before spawning the background loop.
// list_and_watch returns existing instances first, so we poll with a short timeout
// to process all initial workers synchronously before the router becomes "ready".
loop
{
// Use a short timeout to detect when initial discovery events are exhausted
let
poll_result
=
tokio
::
time
::
timeout
(
Duration
::
from_millis
(
100
),
instance_event_stream
.next
())
.await
;
// Wait for at least one worker with a known runtime config before proceeding.
// This ensures we have actual config data (including enable_local_indexer) available.
tracing
::
info!
(
"KV subscriber waiting for at least one worker with runtime config..."
);
let
ready_workers
=
worker_query_client
.wait_for_ready
()
.await
;
tracing
::
info!
(
"KV subscriber found {} worker(s) with runtime config, proceeding"
,
ready_workers
.len
()
);
match
poll_result
{
Ok
(
Some
(
Ok
(
event
)))
=>
{
handle_worker_discovery
(
event
,
&
worker_query_client
,
&
kv_events_tx
,
&
remove_worker_tx
,
)
.await
;
}
Ok
(
Some
(
Err
(
e
)))
=>
{
tracing
::
warn!
(
"Error receiving discovery event during initial sync: {e}"
);
// Recover initial state from all ready workers
for
worker_id
in
&
ready_workers
{
if
worker_query_client
.has_local_indexer
(
*
worker_id
)
{
match
recover_from_worker
(
&
worker_query_client
,
*
worker_id
,
None
,
None
,
&
kv_events_tx
)
.await
{
Ok
(
count
)
=>
{
tracing
::
info!
(
"Successfully recovered {count} events from worker {worker_id}"
);
}
Ok
(
None
)
=>
{
// Stream ended
tracing
::
warn!
(
"Discovery stream ended during initial sync"
);
break
;
Err
(
e
)
=>
{
tracing
::
warn!
(
"Failed to recover from worker {worker_id}: {e}"
);
}
Err
(
_
)
=>
{
// Timeout - no more initial events
tracing
::
debug!
(
"Initial worker discovery sync complete"
);
break
;
}
}
}
// Get instance discovery stream for ongoing monitoring of worker add/remove events
let
mut
instance_event_stream
=
get_instance_discovery_stream
(
&
component
,
&
cancellation_token
)
.await
?
;
tokio
::
spawn
(
async
move
{
// Track last received event ID per worker for gap detection
let
mut
last_event_ids
:
HashMap
<
WorkerId
,
u64
>
=
HashMap
::
new
();
...
...
lib/llm/src/kv_router/worker_query.rs
View file @
584020f4
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
use
anyhow
::{
Context
,
Result
};
...
...
@@ -11,46 +10,49 @@ use dynamo_runtime::pipeline::{
SingleIn
,
async_trait
,
network
::
Ingress
,
};
use
dynamo_runtime
::
protocols
::
maybe_error
::
MaybeError
;
use
tokio
::
sync
::
{
OnceCell
,
watch
}
;
use
tokio
::
sync
::
OnceCell
;
use
tokio_stream
::
StreamExt
;
use
crate
::
discovery
::
RuntimeConfigsSubscriber
;
use
crate
::
kv_router
::
WORKER_KV_INDEXER_QUERY_ENDPOINT
;
use
crate
::
kv_router
::
indexer
::{
LocalKvIndexer
,
WorkerKvQueryRequest
,
WorkerKvQueryResponse
};
use
crate
::
kv_router
::
protocols
::
WorkerId
;
use
crate
::
local_model
::
runtime_config
::
ModelRuntimeConfig
;
use
dynamo_runtime
::
stream
;
/// Router-side client for querying worker local KV indexers
///
/// Performs request/reply communication with workers via request plane endpoint routing.
/// (Only queries workers that have `enable_local_indexer=true` in their MDC user_data)
/// The client is spawned by KvRouter; it
watches same discovery stream as the router
.
/// The client is spawned by KvRouter; it
uses a subscriber from RuntimeConfigs
.
pub
struct
WorkerQueryClient
{
component
:
Component
,
///
Watch receiver for enable_local_indexer state per worker
model_runtime_config_rx
:
watch
::
Receiver
<
HashMap
<
WorkerId
,
Model
RuntimeConfig
>>
,
///
Subscriber for runtime configs (includes shared configs DashMap)
subscriber
:
RuntimeConfig
sSubscriber
,
router
:
OnceCell
<
Arc
<
PushRouter
<
WorkerKvQueryRequest
,
WorkerKvQueryResponse
>>>
,
}
impl
WorkerQueryClient
{
/// Create a new WorkerQueryClient with a watch receiver for local indexer states
pub
fn
new
(
component
:
Component
,
model_runtime_config_rx
:
watch
::
Receiver
<
HashMap
<
WorkerId
,
ModelRuntimeConfig
>>
,
)
->
Self
{
/// Create a new WorkerQueryClient with a subscriber to runtime configs
pub
fn
new
(
component
:
Component
,
subscriber
:
RuntimeConfigsSubscriber
)
->
Self
{
Self
{
component
,
model_runtime_config_rx
,
subscriber
,
router
:
OnceCell
::
new
(),
}
}
/// Wait until at least one worker has a known runtime config (Some).
///
/// This is a thin wrapper: it awaits `wait_for_some()` on the runtime-config
/// subscriber held by this client and passes the result through unchanged.
///
/// Returns the list of worker IDs that have configs.
pub
async
fn
wait_for_ready
(
&
mut
self
)
->
Vec
<
WorkerId
>
{
self
.subscriber
.wait_for_some
()
.await
}
/// Check if a worker has local indexer enabled
pub
fn
has_local_indexer
(
&
self
,
worker_id
:
WorkerId
)
->
bool
{
self
.
model_runtime_config_rx
.
borrow
()
self
.
subscriber
.
configs
.get
(
&
worker_id
)
.
map
(|
config
|
config
.enable_local_indexer
)
.
and_then
(|
entry
|
entry
.value
()
.as_ref
()
.map
(|
c
|
c
.enable_local_indexer
)
)
.unwrap_or
(
false
)
}
...
...
lib/runtime/src/config/environment_names.rs
View file @
584020f4
...
...
@@ -301,6 +301,10 @@ pub mod model {
/// Hugging Face home directory
pub
const
HF_HOME
:
&
str
=
"HF_HOME"
;
/// Offline mode - skip API calls when model is cached
/// Set to "1" or "true" to enable
pub
const
HF_HUB_OFFLINE
:
&
str
=
"HF_HUB_OFFLINE"
;
}
}
...
...
@@ -436,6 +440,7 @@ mod tests {
model
::
huggingface
::
HF_TOKEN
,
model
::
huggingface
::
HF_HUB_CACHE
,
model
::
huggingface
::
HF_HOME
,
model
::
huggingface
::
HF_HUB_OFFLINE
,
// Event Plane
event_plane
::
DYN_EVENT_PLANE
,
event_plane
::
DYN_EVENT_PLANE_CODEC
,
...
...
lib/runtime/src/transports/nats.rs
View file @
584020f4
...
...
@@ -335,12 +335,6 @@ impl ClientOptions {
let
js_ctx
=
jetstream
::
new
(
client
.clone
());
// Validate JetStream is available
js_ctx
.query_account
()
.await
.map_err
(|
e
|
anyhow
::
anyhow!
(
"JetStream not available: {e}"
))
?
;
Ok
(
Client
{
client
,
js_ctx
})
}
}
...
...
tests/conftest.py
View file @
584020f4
...
...
@@ -298,7 +298,7 @@ class EtcdServer(ManagedProcess):
class
NatsServer
(
ManagedProcess
):
def
__init__
(
self
,
request
,
port
=
4222
,
timeout
=
300
):
def
__init__
(
self
,
request
,
port
=
4222
,
timeout
=
300
,
disable_jetstream
=
False
):
# Allocate a free port if port is 0
use_random_port
=
port
==
0
if
use_random_port
:
...
...
@@ -309,16 +309,16 @@ class NatsServer(ManagedProcess):
self
.
use_random_port
=
use_random_port
# Track if we allocated the port
self
.
_request
=
request
# Store for restart
self
.
_timeout
=
timeout
data_dir
=
tempfile
.
mkdtemp
(
prefix
=
"nats_"
)
self
.
_disable_jetstream
=
disable_jetstream
data_dir
=
tempfile
.
mkdtemp
(
prefix
=
"nats_"
)
if
not
disable_jetstream
else
None
command
=
[
"nats-server"
,
"-js"
,
"--trace"
,
"--store_dir"
,
data_dir
,
"-p"
,
str
(
port
),
]
if
not
disable_jetstream
and
data_dir
:
command
.
extend
([
"-js"
,
"--store_dir"
,
data_dir
])
super
().
__init__
(
command
=
command
,
timeout
=
timeout
,
...
...
@@ -326,9 +326,45 @@ class NatsServer(ManagedProcess):
terminate_existing
=
not
use_random_port
,
# Disabled for parallel test execution with random ports
data_dir
=
data_dir
,
health_check_ports
=
[
port
],
health_check_funcs
=
[
self
.
_nats_ready
],
log_dir
=
request
.
node
.
name
,
)
def
_nats_ready
(
self
,
timeout
:
float
=
5
)
->
bool
:
"""Verify NATS server is ready by connecting and optionally checking JetStream."""
import
asyncio
import
nats
async
def
check
():
try
:
nc
=
await
nats
.
connect
(
f
"nats://localhost:
{
self
.
port
}
"
,
connect_timeout
=
min
(
timeout
,
2
),
)
try
:
if
not
self
.
_disable_jetstream
:
# Verify JetStream is initialized
js
=
nc
.
jetstream
()
await
js
.
account_info
()
return
True
finally
:
await
nc
.
close
()
except
Exception
:
return
False
# Handle both sync and async contexts
try
:
asyncio
.
get_running_loop
()
# Check if we're in async context
# Already in async context - run in a thread to avoid blocking
import
concurrent.futures
with
concurrent
.
futures
.
ThreadPoolExecutor
()
as
pool
:
return
pool
.
submit
(
asyncio
.
run
,
check
()).
result
(
timeout
=
timeout
)
except
RuntimeError
:
# No running loop - safe to use asyncio.run()
return
asyncio
.
run
(
check
())
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
"""Release allocated port when server exits."""
try
:
...
...
@@ -344,9 +380,10 @@ class NatsServer(ManagedProcess):
"""Stop the NATS server for restart. Does not release port or clean up fully."""
_logger
.
info
(
f
"Stopping NATS server on port
{
self
.
port
}
"
)
self
.
_terminate_process_group
()
if
self
.
proc
:
proc
=
self
.
proc
# type: ignore[has-type]
if
proc
is
not
None
:
try
:
self
.
proc
.
wait
(
timeout
=
10
)
proc
.
wait
(
timeout
=
10
)
except
Exception
as
e
:
_logger
.
warning
(
f
"Error waiting for NATS process to stop:
{
e
}
"
)
self
.
proc
=
None
...
...
@@ -354,130 +391,130 @@ class NatsServer(ManagedProcess):
def
start
(
self
):
"""Restart a stopped NATS server with fresh state."""
_logger
.
info
(
f
"Starting NATS server on port
{
self
.
port
}
with fresh state"
)
# Clean up old data directory and create fresh one
if
self
.
data_dir
:
shutil
.
rmtree
(
self
.
data_dir
,
ignore_errors
=
True
)
# Clean up old data directory and create fresh one (only if JetStream enabled)
if
not
self
.
_disable_jetstream
:
old_data_dir
=
self
.
data_dir
# type: ignore[has-type]
if
old_data_dir
is
not
None
:
shutil
.
rmtree
(
old_data_dir
,
ignore_errors
=
True
)
self
.
data_dir
=
tempfile
.
mkdtemp
(
prefix
=
"nats_"
)
# Rebuild command
with new data_dir
# Rebuild command
self
.
command
=
[
"nats-server"
,
"-js"
,
"--trace"
,
"--store_dir"
,
self
.
data_dir
,
"-p"
,
str
(
self
.
port
),
]
if
not
self
.
_disable_jetstream
and
self
.
data_dir
:
self
.
command
.
extend
([
"-js"
,
"--store_dir"
,
self
.
data_dir
])
self
.
_start_process
()
self
.
_check_ports
(
self
.
_timeout
)
elapsed
=
self
.
_check_ports
(
self
.
_timeout
)
self
.
_check_funcs
(
self
.
_timeout
-
elapsed
)
class
SharedManagedProcess
:
"""Base class for ManagedProcess with file-based reference counting for multi-process sharing."""
"""Base class for persistent shared processes across pytest-xdist workers.
Simplified design: first worker starts the process on a dynamic port, it lives forever
(until the container dies). No ref counting, no teardown. Subsequent workers just
reuse via port check. This eliminates race conditions and simplifies the logic.
"""
def
__init__
(
self
,
request
,
tmp_path_factory
,
resource_name
:
str
,
port
:
int
,
start_
port
:
int
,
timeout
:
int
=
300
,
):
self
.
request
=
request
self
.
port
=
port
self
.
start_port
=
start_port
self
.
port
:
Optional
[
int
]
=
None
# Set when entering context
self
.
timeout
=
timeout
self
.
resource_name
=
resource_name
self
.
_server
:
Optional
[
ManagedProcess
]
=
None
self
.
_owns_process
=
False
root_tmp
=
Path
(
tempfile
.
gettempdir
())
/
"pytest_
ref_counting
"
root_tmp
=
Path
(
tempfile
.
gettempdir
())
/
"pytest_
shared_services
"
root_tmp
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
self
.
ref
_file
=
root_tmp
/
f
"
pytest_
{
resource_name
}
_
{
port
}
_ref_count
"
self
.
lock_file
=
str
(
self
.
ref
_file
)
+
".lock"
self
.
port
_file
=
root_tmp
/
f
"
{
resource_name
}
_port"
self
.
lock_file
=
str
(
self
.
port
_file
)
+
".lock"
def
_create_server
(
self
)
->
ManagedProcess
:
def
_create_server
(
self
,
port
:
int
)
->
ManagedProcess
:
"""Create the underlying server instance. Must be implemented by subclasses."""
raise
NotImplementedError
def
_read_ref_count
(
self
)
->
int
:
"""Read current reference count."""
if
self
.
ref_file
.
exists
():
def
_is_port_in_use
(
self
,
port
:
int
)
->
bool
:
"""Check if a port is in use (i.e., a process is listening on it)."""
import
socket
try
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
settimeout
(
1
)
result
=
sock
.
connect_ex
((
"localhost"
,
port
))
sock
.
close
()
return
result
==
0
# 0 means connection succeeded (port in use)
except
Exception
:
return
False
def
_read_port
(
self
)
->
Optional
[
int
]:
"""Read stored port from file."""
if
self
.
port_file
.
exists
():
try
:
return
int
(
self
.
ref
_file
.
read_text
().
strip
())
return
int
(
self
.
port
_file
.
read_text
().
strip
())
except
(
ValueError
,
IOError
):
return
0
return
0
def
_write_ref_count
(
self
,
count
:
int
):
"""Write reference count atomically."""
self
.
ref_file
.
write_text
(
str
(
count
))
def
_increment_ref_count
(
self
)
->
int
:
"""Increment reference count and return new count."""
count
=
self
.
_read_ref_count
()
count
+=
1
self
.
_write_ref_count
(
count
)
return
count
def
_decrement_ref_count
(
self
)
->
int
:
"""Decrement reference count and return new count."""
count
=
self
.
_read_ref_count
()
count
=
max
(
0
,
count
-
1
)
self
.
_write_ref_count
(
count
)
return
count
return
None
return
None
def
_write_port
(
self
,
port
:
int
):
"""Write port to file."""
self
.
port_file
.
write_text
(
str
(
port
))
def
__enter__
(
self
):
with
FileLock
(
self
.
lock_file
):
ref_count
=
self
.
_increment_ref_count
()
if
ref_count
==
1
:
# First reference - start the process
self
.
_server
=
self
.
_create_server
()
self
.
_server
.
__enter__
()
self
.
_owns_process
=
True
logging
.
info
(
f
"[
{
self
.
resource_name
}
] Started process (ref_count=1)"
)
else
:
# Process already running, just track reference
self
.
_owns_process
=
False
stored_port
=
self
.
_read_port
()
# Check if a process is already running on the stored port
if
stored_port
is
not
None
and
self
.
_is_port_in_use
(
stored_port
):
# Reuse existing process
self
.
port
=
stored_port
logging
.
info
(
f
"[
{
self
.
resource_name
}
] Reusing existing process
(ref_count=
{
ref_coun
t
}
)
"
f
"[
{
self
.
resource_name
}
] Reusing existing process
on port
{
self
.
por
t
}
"
)
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
with
FileLock
(
self
.
lock_file
):
ref_count
=
self
.
_decrement_ref_count
()
if
ref_count
==
0
and
self
.
_owns_process
:
# Last reference - stop the process
if
self
.
_server
:
self
.
_server
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
logging
.
info
(
f
"[
{
self
.
resource_name
}
] Stopped process (ref_count=0)"
)
elif
ref_count
==
0
:
# Last reference but we don't own it - shouldn't happen, but clean up ref file
if
self
.
ref_file
.
exists
():
self
.
ref_file
.
unlink
()
else
:
# Start new process
if
stored_port
is
not
None
:
logging
.
warning
(
f
"[
{
self
.
resource_name
}
]
Ref count reached 0 but we don't own proc
es
s
"
f
"[
{
self
.
resource_name
}
]
Stale port file: port
{
stored_port
}
not in use, starting fr
es
h
"
)
else
:
self
.
port
=
allocate_port
(
self
.
start_port
)
self
.
_write_port
(
self
.
port
)
self
.
_server
=
self
.
_create_server
(
self
.
port
)
self
.
_server
.
__enter__
()
logging
.
info
(
f
"[
{
self
.
resource_name
}
]
Released reference (ref_count=
{
ref_coun
t
}
)
"
f
"[
{
self
.
resource_name
}
]
Started process on port
{
self
.
por
t
}
"
)
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
# Never tear down - let the process live until the container dies.
# This avoids race conditions and simplifies the logic.
pass
class
SharedEtcdServer
(
SharedManagedProcess
):
"""EtcdServer with file-based reference counting for multi-process sharing."""
def
__init__
(
self
,
request
,
tmp_path_factory
,
port
=
23
79
,
timeout
=
300
):
super
().
__init__
(
request
,
tmp_path_factory
,
"etcd"
,
port
,
timeout
)
def
__init__
(
self
,
request
,
tmp_path_factory
,
start_
port
=
23
80
,
timeout
=
300
):
super
().
__init__
(
request
,
tmp_path_factory
,
"etcd"
,
start_
port
,
timeout
)
# Create a log directory for session-scoped servers
self
.
_log_dir
=
tempfile
.
mkdtemp
(
prefix
=
f
"pytest_
{
self
.
resource_name
}
_logs_"
)
def
_create_server
(
self
)
->
ManagedProcess
:
def
_create_server
(
self
,
port
:
int
)
->
ManagedProcess
:
"""Create EtcdServer instance."""
server
=
EtcdServer
(
self
.
request
,
port
=
self
.
port
,
timeout
=
self
.
timeout
)
server
=
EtcdServer
(
self
.
request
,
port
=
port
,
timeout
=
self
.
timeout
)
# Override log_dir since request.node.name is empty in session scope
server
.
log_dir
=
self
.
_log_dir
return
server
...
...
@@ -486,14 +523,27 @@ class SharedEtcdServer(SharedManagedProcess):
class
SharedNatsServer
(
SharedManagedProcess
):
"""NatsServer with file-based reference counting for multi-process sharing."""
def
__init__
(
self
,
request
,
tmp_path_factory
,
port
=
4222
,
timeout
=
300
):
super
().
__init__
(
request
,
tmp_path_factory
,
"nats"
,
port
,
timeout
)
def
__init__
(
self
,
request
,
tmp_path_factory
,
start_port
=
4223
,
timeout
=
300
,
disable_jetstream
=
False
,
):
super
().
__init__
(
request
,
tmp_path_factory
,
"nats"
,
start_port
,
timeout
)
# Create a log directory for session-scoped servers
self
.
_log_dir
=
tempfile
.
mkdtemp
(
prefix
=
f
"pytest_
{
self
.
resource_name
}
_logs_"
)
self
.
_disable_jetstream
=
disable_jetstream
def
_create_server
(
self
)
->
ManagedProcess
:
def
_create_server
(
self
,
port
:
int
)
->
ManagedProcess
:
"""Create NatsServer instance."""
server
=
NatsServer
(
self
.
request
,
port
=
self
.
port
,
timeout
=
self
.
timeout
)
server
=
NatsServer
(
self
.
request
,
port
=
port
,
timeout
=
self
.
timeout
,
disable_jetstream
=
self
.
_disable_jetstream
,
)
# Override log_dir since request.node.name is empty in session scope
server
.
log_dir
=
self
.
_log_dir
return
server
...
...
@@ -525,6 +575,27 @@ def request_plane(request):
return
getattr
(
request
,
"param"
,
"nats"
)
@pytest.fixture
def use_nats_core(request):
    """
    Select NATS Core mode (local indexer) instead of JetStream. Defaults to False.

    The value is taken from the indirect parametrization of this fixture;
    when the test does not parametrize it, False is returned.

    When True:
    - NATS server starts without JetStream (-js flag omitted) for faster startup
    - Tests should use enable_local_indexer=True in mocker_args

    When False (default):
    - NATS server starts with JetStream for KV event distribution
    - Tests use JetStream-based indexer synchronization

    To use NATS Core mode:
        @pytest.mark.parametrize("use_nats_core", [True], indirect=True)
        def test_example(runtime_services_dynamic_ports):
            ...
    """
    try:
        # ``param`` only exists when the fixture is parametrized.
        return request.param
    except AttributeError:
        return False
@
pytest
.
fixture
()
def
runtime_services
(
request
,
store_kv
,
request_plane
):
"""
...
...
@@ -551,7 +622,7 @@ def runtime_services(request, store_kv, request_plane):
@
pytest
.
fixture
()
def
runtime_services_dynamic_ports
(
request
,
store_kv
,
request_plane
):
def
runtime_services_dynamic_ports
(
request
,
store_kv
,
request_plane
,
use_nats_core
):
"""Provide NATS and Etcd servers with truly dynamic ports per test.
This fixture actually allocates dynamic ports by passing port=0 to the servers.
...
...
@@ -566,6 +637,7 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane):
- If store_kv != "etcd", etcd is not started (returns None)
- NATS is always started when etcd is used, because KV events require NATS
regardless of the request_plane (tcp/nats only affects request transport)
- JetStream is enabled by default; disabled when use_nats_core=True for faster startup
Returns a tuple of (nats_process, etcd_process) where each has a .port attribute.
"""
...
...
@@ -573,23 +645,41 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane):
# Port cleanup is now handled in NatsServer and EtcdServer __exit__ methods
# Always start NATS when etcd is used - KV events require NATS regardless of request_plane
# When use_nats_core=True, disable JetStream for faster startup
if
store_kv
==
"etcd"
:
with
NatsServer
(
request
,
port
=
0
)
as
nats_process
:
with
NatsServer
(
request
,
port
=
0
,
disable_jetstream
=
use_nats_core
)
as
nats_process
:
with
EtcdServer
(
request
,
port
=
0
)
as
etcd_process
:
# Set environment variables for Rust/Python runtime to use. Note that xdist (parallel execution)
# will launch isolated tests in a new process, so no need to worry about environment pollution.
# Save original env vars (may be set by session-scoped fixture)
orig_nats
=
os
.
environ
.
get
(
"NATS_SERVER"
)
orig_etcd
=
os
.
environ
.
get
(
"ETCD_ENDPOINTS"
)
# Set environment variables for this test's dynamic ports
os
.
environ
[
"NATS_SERVER"
]
=
f
"nats://localhost:
{
nats_process
.
port
}
"
os
.
environ
[
"ETCD_ENDPOINTS"
]
=
f
"http://localhost:
{
etcd_process
.
port
}
"
yield
nats_process
,
etcd_process
# No test should rely on these variables after the test, but clean up just in case.
# Restore original env vars (or remove if they weren't set)
if
orig_nats
is
not
None
:
os
.
environ
[
"NATS_SERVER"
]
=
orig_nats
else
:
os
.
environ
.
pop
(
"NATS_SERVER"
,
None
)
if
orig_etcd
is
not
None
:
os
.
environ
[
"ETCD_ENDPOINTS"
]
=
orig_etcd
else
:
os
.
environ
.
pop
(
"ETCD_ENDPOINTS"
,
None
)
elif
request_plane
==
"nats"
:
with
NatsServer
(
request
,
port
=
0
)
as
nats_process
:
with
NatsServer
(
request
,
port
=
0
,
disable_jetstream
=
use_nats_core
)
as
nats_process
:
orig_nats
=
os
.
environ
.
get
(
"NATS_SERVER"
)
os
.
environ
[
"NATS_SERVER"
]
=
f
"nats://localhost:
{
nats_process
.
port
}
"
yield
nats_process
,
None
if
orig_nats
is
not
None
:
os
.
environ
[
"NATS_SERVER"
]
=
orig_nats
else
:
os
.
environ
.
pop
(
"NATS_SERVER"
,
None
)
else
:
yield
None
,
None
...
...
@@ -599,22 +689,28 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane):
def
runtime_services_session
(
request
,
tmp_path_factory
):
"""Session-scoped fixture that provides shared NATS and etcd instances for all tests.
Uses file-based reference counting to coordinate between pytest-xdist worker processes.
Only the first worker starts services, and only the last worker tears them down.
Uses file locking to coordinate between pytest-xdist worker processes.
First worker starts services on dynamic ports, subsequent workers reuse them.
Services are never torn down (live until container dies) to avoid race conditions.
WARNING: may not be parallel/xdist safe.
- This fixture shares one NATS + one etcd across many tests (and across xdist workers).
- It is only safe if tests fully isolate state (e.g. unique namespaces) and do not
assume exclusive access to global streams/keys/ports.
- Prefer `runtime_services_dynamic_ports` for true per-test isolation in parallel runs.
This fixture is xdist-safe when tests use unique namespaces (e.g. random suffixes)
and do not assume exclusive access to global streams/keys.
TODO: once nothing
use
s
`runtime_services_
session`, make the per-test
dynamic
ports
behavior the default for router/fron
te
n
d in
tegration test
s.
For tests that need to restart NATS (e.g. indexer sync),
use `runtime_services_dynamic
_
ports
`
which provides per-test isola
ted in
stance
s.
"""
with
SharedNatsServer
(
request
,
tmp_path_factory
)
as
nats
:
with
SharedEtcdServer
(
request
,
tmp_path_factory
)
as
etcd
:
# Set environment variables for Rust/Python runtime to use
os
.
environ
[
"NATS_SERVER"
]
=
f
"nats://localhost:
{
nats
.
port
}
"
os
.
environ
[
"ETCD_ENDPOINTS"
]
=
f
"http://localhost:
{
etcd
.
port
}
"
yield
nats
,
etcd
# Clean up environment variables
os
.
environ
.
pop
(
"NATS_SERVER"
,
None
)
os
.
environ
.
pop
(
"ETCD_ENDPOINTS"
,
None
)
@
pytest
.
fixture
def
file_storage_backend
():
...
...
tests/router/common.py
View file @
584020f4
...
...
@@ -525,10 +525,10 @@ async def send_request_via_python_kv_router(
stream
=
await
kv_python_router
.
generate
(
token_ids
=
token_ids
,
model
=
model_name
,
stop_conditions
=
stop_conditions
,
sampling_options
=
sampling_options
,
output_options
=
output_options
,
router_config_override
=
router_config_override
,
stop_conditions
=
stop_conditions
,
# type: ignore[arg-type]
sampling_options
=
sampling_options
,
# type: ignore[arg-type]
output_options
=
output_options
,
# type: ignore[arg-type]
router_config_override
=
router_config_override
,
# type: ignore[arg-type]
worker_id
=
worker_id
,
dp_rank
=
dp_rank
,
)
...
...
@@ -693,6 +693,7 @@ def _test_router_two_routers(
test_payload
:
dict
,
num_requests
:
int
,
store_backend
:
str
=
"etcd"
,
skip_consumer_verification
:
bool
=
False
,
):
"""Test two KV routers with alternating requests and consumer lifecycle verification.
...
...
@@ -701,8 +702,8 @@ def _test_router_two_routers(
This test:
1. Starts two KV routers on different ports
2. Sends requests alternating between the two routers
3. Verifies that both routers create durable consumers
4. Verifies consumers are cleaned up when routers exit
3. Verifies that both routers create durable consumers
(unless skipped)
4. Verifies consumers are cleaned up when routers exit
(unless skipped)
Args:
engine_workers: Backend workers (mocker/vllm) already initialized with __enter__()
...
...
@@ -712,6 +713,7 @@ def _test_router_two_routers(
test_payload: Test payload to send to /v1/chat/completions
num_requests: Number of concurrent requests to send
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
skip_consumer_verification: Skip JetStream consumer verification (for NATS Core mode).
Raises:
AssertionError: If consumer lifecycle verification fails
...
...
@@ -846,7 +848,13 @@ def _test_router_two_routers(
finally
:
await
nc
.
close
()
# Run consumer lifecycle verification
# Run consumer lifecycle verification (skip for NATS Core mode)
if
skip_consumer_verification
:
logger
.
info
(
"Skipping JetStream consumer verification (NATS Core mode)"
)
# Clean up routers manually since we're not doing consumer verification
for
kv_router
in
kv_routers
:
kv_router
.
__exit__
(
None
,
None
,
None
)
else
:
asyncio
.
run
(
verify_consumer_lifecycle
())
# Clear the kv_routers list since we've already cleaned them up
...
...
tests/router/test_router_e2e_with_mockers.py
View file @
584020f4
...
...
@@ -323,8 +323,15 @@ class DisaggMockerProcess:
@
pytest
.
mark
.
timeout
(
42
)
# ~3x average (~13.80s), rounded up
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"use_nats_core"
,
[
True
],
indirect
=
True
)
# Use NATS Core (local indexer)
def
test_mocker_kv_router
(
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
request_plane
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
request_plane
,
use_nats_core
,
):
"""
Test KV router with multiple mocker engine instances.
...
...
@@ -335,8 +342,12 @@ def test_mocker_kv_router(
# runtime_services starts etcd and optionally nats based on request_plane
logger
.
info
(
f
"Starting mocker KV router test with request_plane=
{
request_plane
}
"
)
# Create mocker args dictionary
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
# Create mocker args dictionary - use local indexer (NATS Core mode)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"enable_local_indexer"
:
use_nats_core
,
}
try
:
# Start mocker instances with the new CLI interface
...
...
@@ -372,6 +383,9 @@ def test_mocker_kv_router(
@
pytest
.
mark
.
parametrize
(
"store_backend"
,
[
"etcd"
,
"file"
])
@
pytest
.
mark
.
parametrize
(
"use_nats_core"
,
[
True
],
indirect
=
True
)
# Use NATS Core (local indexer)
@
pytest
.
mark
.
timeout
(
60
)
# ~3x average (~19.86s), rounded up
def
test_mocker_two_kv_router
(
request
,
...
...
@@ -379,6 +393,7 @@ def test_mocker_two_kv_router(
predownload_tokenizers
,
file_storage_backend
,
store_backend
,
use_nats_core
,
):
"""
Test with two KV routers and multiple mocker engine instances.
...
...
@@ -391,8 +406,12 @@ def test_mocker_two_kv_router(
f
"Starting mocker two KV router test with
{
store_backend
}
storage backend"
)
# Create mocker args dictionary
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
# Create mocker args dictionary - use local indexer (NATS Core mode)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"enable_local_indexer"
:
use_nats_core
,
}
try
:
# Start mocker instances with the new CLI interface
...
...
@@ -420,6 +439,7 @@ def test_mocker_two_kv_router(
test_payload
=
TEST_PAYLOAD
,
num_requests
=
NUM_REQUESTS
,
store_backend
=
store_backend
,
skip_consumer_verification
=
use_nats_core
,
# Skip JetStream checks in NATS Core mode
)
finally
:
...
...
@@ -428,17 +448,21 @@ def test_mocker_two_kv_router(
@
pytest
.
mark
.
skip
(
reason
=
"Flaky, temporarily disabled"
)
@
pytest
.
mark
.
parametrize
(
"use_nats_core"
,
[
True
],
indirect
=
True
)
# Use NATS Core (local indexer)
@
pytest
.
mark
.
timeout
(
60
)
# ~3x average (~19.86s), rounded up (when enabled)
def
test_mocker_kv_router_overload_503
(
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
use_nats_core
):
"""Test that KV router returns 503 when mocker workers are overloaded."""
logger
.
info
(
"Starting mocker KV router overload test for 503 status"
)
# Create mocker args dictionary with limited resources
# Create mocker args dictionary with limited resources
- use local indexer (NATS Core mode)
mocker_args
=
{
"speedup_ratio"
:
10
,
"block_size"
:
4
,
# Smaller block size
"num_gpu_blocks"
:
64
,
# Limited GPU blocks to exhaust quickly
"enable_local_indexer"
:
use_nats_core
,
}
try
:
...
...
@@ -468,12 +492,24 @@ def test_mocker_kv_router_overload_503(
@
pytest
.
mark
.
timeout
(
22
)
# ~3x average (~7.10s), rounded up
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"use_nats_core"
,
[
True
],
indirect
=
True
)
# Use NATS Core (local indexer)
def
test_kv_push_router_bindings
(
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
request_plane
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
request_plane
,
use_nats_core
,
):
"""Test KvPushRouter Python bindings with mocker engines."""
logger
.
info
(
"Starting KvPushRouter bindings test"
)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
# Use local indexer (NATS Core mode)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"enable_local_indexer"
:
use_nats_core
,
}
try
:
# Start mocker instances
...
...
@@ -507,19 +543,19 @@ def test_kv_push_router_bindings(
mockers
.
__exit__
(
None
,
None
,
None
)
# NO @pytest.mark.parallel - nats_core variant stops/restarts NATS
@
pytest
.
mark
.
parametrize
(
"store_backend,use_nats_core,request_plane"
,
[
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
(
"etcd"
,
True
,
"tcp"
),
# NATS core mode (with gap detection)
(
"file"
,
False
,
"nats"
),
# File backend
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
- uses JetStream (default)
(
"etcd"
,
True
,
"tcp"
),
# NATS core mode (with gap detection)
- no JetStream
(
"file"
,
False
,
"nats"
),
# File backend
- uses JetStream (default)
],
ids
=
[
"jetstream"
,
"nats_core"
,
"file"
,
],
indirect
=
[
"request_plane"
,
"use_nats_core"
],
)
@
pytest
.
mark
.
timeout
(
90
)
# TODO: figure out a timeout
def
test_indexers_sync
(
...
...
@@ -590,12 +626,20 @@ def test_indexers_sync(
@
pytest
.
mark
.
timeout
(
42
)
# ~3x average (~13.80s), rounded up
@
pytest
.
mark
.
parametrize
(
"use_nats_core"
,
[
True
],
indirect
=
True
)
# Use NATS Core (local indexer)
def
test_query_instance_id_returns_worker_and_tokens
(
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
use_nats_core
):
"""Test query_instance_id annotation with mocker engines."""
logger
.
info
(
"Starting KV router query_instance_id annotation test"
)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
# Use local indexer (NATS Core mode)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"enable_local_indexer"
:
use_nats_core
,
}
os
.
makedirs
(
request
.
node
.
name
,
exist_ok
=
True
)
try
:
...
...
@@ -629,11 +673,12 @@ def test_query_instance_id_returns_worker_and_tokens(
@
pytest
.
mark
.
parametrize
(
"use_nats_core,use_kv_events"
,
[
(
False
,
True
),
# JetStream mode (default)
(
True
,
True
),
# NATS Core + local indexer mode
(
False
,
False
),
# Approximate mode (--no-kv-events)
(
False
,
True
),
# JetStream mode (default)
- uses JetStream
(
True
,
True
),
# NATS Core + local indexer mode
- no JetStream
(
False
,
False
),
# Approximate mode (--no-kv-events)
- uses JetStream
],
ids
=
[
"jetstream"
,
"nats_core"
,
"no_kv_events"
],
indirect
=
[
"use_nats_core"
],
)
def
test_router_decisions
(
request
,
...
...
@@ -828,9 +873,16 @@ def test_router_decisions_disagg(
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"use_nats_core"
,
[
True
],
indirect
=
True
)
# Use NATS Core (local indexer)
@
pytest
.
mark
.
timeout
(
39
)
# ~3x average (~12.84s), rounded up
def
test_busy_threshold_endpoint
(
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
request_plane
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
request_plane
,
use_nats_core
,
):
"""Test that the /busy_threshold endpoint can be hit and responds correctly.
...
...
@@ -846,7 +898,12 @@ def test_busy_threshold_endpoint(
f
"Starting busy_threshold endpoint test with request_plane=
{
request_plane
}
"
)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
# Use local indexer (NATS Core mode)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"enable_local_indexer"
:
use_nats_core
,
}
try
:
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
...
...
tests/router/test_router_e2e_with_sglang.py
View file @
584020f4
...
...
@@ -87,6 +87,7 @@ class SGLangProcess:
data_parallel_size
:
Optional
[
int
]
=
None
,
request_plane
:
str
=
"tcp"
,
store_backend
:
str
=
"etcd"
,
enable_local_indexer
:
bool
=
False
,
):
"""Initialize SGLang workers with dynamo integration.
...
...
@@ -103,6 +104,7 @@ class SGLangProcess:
data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size)
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
enable_local_indexer: If True, enable worker-local KV indexer for NATS Core mode. Defaults to False.
"""
# Generate unique namespace for isolation
namespace_suffix
=
generate_random_suffix
()
...
...
@@ -192,6 +194,10 @@ class SGLangProcess:
if
self
.
store_backend
==
"file"
and
"DYN_FILE_KV"
in
os
.
environ
:
env_vars
[
"DYN_FILE_KV"
]
=
os
.
environ
[
"DYN_FILE_KV"
]
# Enable local indexer for NATS Core mode
if
enable_local_indexer
:
env_vars
[
"DYN_LOCAL_INDEXER"
]
=
"true"
env
.
update
(
env_vars
)
# Create managed process for the worker
...
...
@@ -313,7 +319,7 @@ class SGLangProcess:
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~46s/test), rounded up
def
test_sglang_kv_router_basic
(
request
,
...
...
@@ -369,7 +375,7 @@ def test_sglang_kv_router_basic(
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
skip
(
reason
=
"Broken by sglang changes"
)
# TODO: Re-enable this test once https://github.com/sgl-project/sglang/pull/14934 is merged
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
def
test_router_decisions_sglang_multiple_workers
(
request
,
runtime_services_dynamic_ports
,
...
...
@@ -414,7 +420,7 @@ def test_router_decisions_sglang_multiple_workers(
@
pytest
.
mark
.
gpu_2
@
pytest
.
mark
.
post_merge
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
timeout
(
600
)
# 10 min max (multi-GPU + DP startup variance)
def
test_router_decisions_sglang_dp
(
request
,
...
...
@@ -469,10 +475,10 @@ def test_router_decisions_sglang_dp(
"store_backend,use_nats_core,request_plane"
,
[
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
# ("etcd", True, "tcp"), #
ignored, needs unconditional nats_client
# ("etcd", True, "tcp"), #
nats_core mode - disabled for now
# ("file", False, "nats"), # File backend - TODO: investigate file backend support for SGLang
],
ids
=
[
"jetstream"
],
# "nats_core" and "file" commented out
ids
=
[
"jetstream"
],
)
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~46s/test), rounded up
def
test_sglang_indexers_sync
(
...
...
@@ -491,8 +497,11 @@ def test_sglang_indexers_sync(
Tests with configuration:
- jetstream: etcd backend, JetStream for KV events, NATS request plane
- tcp_nats_core: etcd backend, local indexer with NATS Core, TCP request plane
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
nats_process
,
_etcd_process
=
runtime_services_dynamic_ports
logger
.
info
(
f
"Starting SGLang indexers sync test: store_backend=
{
store_backend
}
, "
f
"use_nats_core=
{
use_nats_core
}
, request_plane=
{
request_plane
}
"
...
...
@@ -510,6 +519,7 @@ def test_sglang_indexers_sync(
single_gpu
=
True
,
# fit workers into one GPU
request_plane
=
request_plane
,
store_backend
=
store_backend
,
enable_local_indexer
=
use_nats_core
,
)
logger
.
info
(
f
"All SGLang workers using namespace:
{
sglang_workers
.
namespace
}
"
)
sglang_workers
.
__enter__
()
...
...
@@ -523,6 +533,8 @@ def test_sglang_indexers_sync(
num_workers
=
N_SGLANG_WORKERS
,
store_backend
=
store_backend
,
request_plane
=
request_plane
,
test_nats_interruption
=
use_nats_core
,
nats_server
=
nats_process
if
use_nats_core
else
None
,
)
logger
.
info
(
"SGLang indexers sync test completed successfully"
)
...
...
tests/router/test_router_e2e_with_trtllm.py
View file @
584020f4
...
...
@@ -84,6 +84,7 @@ class TRTLLMProcess:
single_gpu
:
bool
=
False
,
request_plane
:
str
=
"tcp"
,
store_backend
:
str
=
"etcd"
,
enable_local_indexer
:
bool
=
False
,
):
"""Initialize TRT-LLM workers with dynamo integration.
...
...
@@ -98,6 +99,7 @@ class TRTLLMProcess:
single_gpu: If True, all workers share GPU 0
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
enable_local_indexer: If True, enable worker-local KV indexer for NATS Core mode. Defaults to False.
Note: TRT-LLM doesn't support data parallelism like vLLM (dp_rank is always 0).
Tensor parallelism (TP) is supported but creates 1 worker spanning multiple GPUs,
...
...
@@ -172,6 +174,10 @@ class TRTLLMProcess:
if
self
.
store_backend
==
"file"
and
"DYN_FILE_KV"
in
os
.
environ
:
env_vars
[
"DYN_FILE_KV"
]
=
os
.
environ
[
"DYN_FILE_KV"
]
# Enable local indexer for NATS Core mode
if
enable_local_indexer
:
env_vars
[
"DYN_LOCAL_INDEXER"
]
=
"true"
env
.
update
(
env_vars
)
# Create managed process for the worker
...
...
@@ -286,7 +292,7 @@ class TRTLLMProcess:
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~45s/test), rounded up
def
test_trtllm_kv_router_basic
(
request
,
...
...
@@ -340,7 +346,7 @@ def test_trtllm_kv_router_basic(
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~45s/test), rounded up
def
test_router_decisions_trtllm_multiple_workers
(
request
,
...
...
@@ -398,10 +404,10 @@ def test_router_decisions_trtllm_multiple_workers(
"store_backend,use_nats_core,request_plane"
,
[
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
# ("etcd", True, "tcp"), #
ignored, needs unconditional nats_client
# ("etcd", True, "tcp"), #
nats_core mode - disabled for now
# ("file", False, "nats"), # File backend - TODO: investigate file backend support for TRT-LLM
],
ids
=
[
"jetstream"
],
# "nats_core" and "file" commented out
ids
=
[
"jetstream"
],
)
def
test_trtllm_indexers_sync
(
request
,
...
...
@@ -419,8 +425,11 @@ def test_trtllm_indexers_sync(
Tests with configuration:
- jetstream: etcd backend, JetStream for KV events, NATS request plane
- tcp_nats_core: etcd backend, local indexer with NATS Core, TCP request plane
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
nats_process
,
_etcd_process
=
runtime_services_dynamic_ports
logger
.
info
(
f
"Starting TRT-LLM indexers sync test: store_backend=
{
store_backend
}
, "
f
"use_nats_core=
{
use_nats_core
}
, request_plane=
{
request_plane
}
"
...
...
@@ -438,6 +447,7 @@ def test_trtllm_indexers_sync(
single_gpu
=
True
,
# fit workers into one GPU
request_plane
=
request_plane
,
store_backend
=
store_backend
,
enable_local_indexer
=
use_nats_core
,
)
logger
.
info
(
f
"All TRT-LLM workers using namespace:
{
trtllm_workers
.
namespace
}
"
)
trtllm_workers
.
__enter__
()
...
...
@@ -451,6 +461,8 @@ def test_trtllm_indexers_sync(
num_workers
=
N_TRTLLM_WORKERS
,
store_backend
=
store_backend
,
request_plane
=
request_plane
,
test_nats_interruption
=
use_nats_core
,
nats_server
=
nats_process
if
use_nats_core
else
None
,
)
logger
.
info
(
"TRT-LLM indexers sync test completed successfully"
)
...
...
tests/router/test_router_e2e_with_vllm.py
View file @
584020f4
...
...
@@ -87,6 +87,7 @@ class VLLMProcess:
data_parallel_size
:
Optional
[
int
]
=
None
,
request_plane
:
str
=
"tcp"
,
store_backend
:
str
=
"etcd"
,
enable_local_indexer
:
bool
=
False
,
):
"""Initialize vLLM workers with dynamo integration.
...
...
@@ -104,6 +105,7 @@ class VLLMProcess:
data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size)
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
enable_local_indexer: If True, enable worker-local KV indexer for NATS Core mode. Defaults to False.
"""
# Generate unique namespace for isolation
namespace_suffix
=
generate_random_suffix
()
...
...
@@ -209,6 +211,10 @@ class VLLMProcess:
if
self
.
store_backend
==
"file"
and
"DYN_FILE_KV"
in
os
.
environ
:
env_vars
[
"DYN_FILE_KV"
]
=
os
.
environ
[
"DYN_FILE_KV"
]
# Enable local indexer for NATS Core mode
if
enable_local_indexer
:
env_vars
[
"DYN_LOCAL_INDEXER"
]
=
"true"
env
.
update
(
env_vars
)
# Create managed process for the worker
...
...
@@ -329,7 +335,7 @@ class VLLMProcess:
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~43s/test), rounded up
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
def
test_vllm_kv_router_basic
(
request
,
runtime_services_dynamic_ports
,
...
...
@@ -383,7 +389,7 @@ def test_vllm_kv_router_basic(
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~43s/test), rounded up
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
def
test_router_decisions_vllm_multiple_workers
(
request
,
runtime_services_dynamic_ports
,
...
...
@@ -428,7 +434,7 @@ def test_router_decisions_vllm_multiple_workers(
@
pytest
.
mark
.
gpu_2
@
pytest
.
mark
.
nightly
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
timeout
(
600
)
# 10 min max (multi-GPU + DP startup variance)
def
test_router_decisions_vllm_dp
(
request
,
...
...
@@ -484,10 +490,10 @@ def test_router_decisions_vllm_dp(
"store_backend,use_nats_core,request_plane"
,
[
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
(
"etcd"
,
True
,
"tcp"
),
# nats_core mode
#
("etcd", True, "tcp"), # nats_core mode
- disabled for now
# ("file", False, "nats"), # File backend
],
ids
=
[
"jetstream"
,
"tcp_nats_core"
],
ids
=
[
"jetstream"
],
)
def
test_vllm_indexers_sync
(
request
,
...
...
@@ -508,6 +514,8 @@ def test_vllm_indexers_sync(
- tcp_nats_core: etcd backend, local indexer with NATS Core, TCP request plane
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
nats_process
,
_etcd_process
=
runtime_services_dynamic_ports
logger
.
info
(
f
"Starting vLLM indexers sync test: store_backend=
{
store_backend
}
, "
f
"use_nats_core=
{
use_nats_core
}
, request_plane=
{
request_plane
}
"
...
...
@@ -525,6 +533,7 @@ def test_vllm_indexers_sync(
single_gpu
=
True
,
# fit workers into one GPU
request_plane
=
request_plane
,
store_backend
=
store_backend
,
enable_local_indexer
=
use_nats_core
,
)
logger
.
info
(
f
"All vLLM workers using namespace:
{
vllm_workers
.
namespace
}
"
)
vllm_workers
.
__enter__
()
...
...
@@ -538,6 +547,8 @@ def test_vllm_indexers_sync(
num_workers
=
N_VLLM_WORKERS
,
store_backend
=
store_backend
,
request_plane
=
request_plane
,
test_nats_interruption
=
use_nats_core
,
nats_server
=
nats_process
if
use_nats_core
else
None
,
)
logger
.
info
(
"vLLM indexers sync test completed successfully"
)
...
...
tests/utils/managed_process.py
View file @
584020f4
...
...
@@ -35,7 +35,7 @@ def terminate_process(process, logger=logging.getLogger(), immediate_kill=False)
def
terminate_process_tree
(
pid
,
logger
=
logging
.
getLogger
(),
immediate_kill
=
False
,
timeout
=
10
pid
,
logger
=
logging
.
getLogger
(),
immediate_kill
=
False
,
timeout
=
2
):
try
:
parent
=
psutil
.
Process
(
pid
)
...
...
@@ -277,7 +277,7 @@ class ManagedProcess:
)
self
.
_tee_proc
=
None
def
_terminate_process_group
(
self
,
timeout
:
float
=
5
.0
):
def
_terminate_process_group
(
self
,
timeout
:
float
=
2
.0
):
"""Terminate the entire process group/session started for the child.
This catches cases where the launcher shell exits and its children are reparented,
...
...
@@ -296,10 +296,23 @@ class ManagedProcess:
)
return
# Give processes a brief moment to exit gracefully
time
.
sleep
(
timeout
)
# Poll for process exit instead of fixed sleep to minimize teardown time
poll_interval
=
0.1
elapsed
=
0.0
while
elapsed
<
timeout
:
try
:
# Check if any process in the group is still alive
os
.
killpg
(
self
.
_pgid
,
0
)
# Signal 0 = check existence
except
ProcessLookupError
:
# Process group no longer exists - done
return
except
Exception
:
# Other errors (e.g., permission) - assume done
return
time
.
sleep
(
poll_interval
)
elapsed
+=
poll_interval
# Force kill if anything remains
# Force kill if anything remains
after timeout
try
:
os
.
killpg
(
self
.
_pgid
,
signal
.
SIGKILL
)
except
ProcessLookupError
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment