Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
584020f4
Unverified
Commit
584020f4
authored
Jan 29, 2026
by
Yan Ru Pei
Committed by
GitHub
Jan 29, 2026
Browse files
chore: consistently used the arc shared workers configs + CI flake fixes (#5707)
Signed-off-by:
PeaBrane
<
yanrpei@gmail.com
>
parent
1174c819
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
729 additions
and
404 deletions
+729
-404
lib/llm/src/discovery.rs
lib/llm/src/discovery.rs
+4
-1
lib/llm/src/discovery/model_manager.rs
lib/llm/src/discovery/model_manager.rs
+11
-140
lib/llm/src/discovery/runtime_configs.rs
lib/llm/src/discovery/runtime_configs.rs
+200
-0
lib/llm/src/hub.rs
lib/llm/src/hub.rs
+64
-0
lib/llm/src/kv_router.rs
lib/llm/src/kv_router.rs
+18
-35
lib/llm/src/kv_router/publisher.rs
lib/llm/src/kv_router/publisher.rs
+6
-9
lib/llm/src/kv_router/scheduler.rs
lib/llm/src/kv_router/scheduler.rs
+14
-8
lib/llm/src/kv_router/subscriber.rs
lib/llm/src/kv_router/subscriber.rs
+30
-39
lib/llm/src/kv_router/worker_query.rs
lib/llm/src/kv_router/worker_query.rs
+17
-15
lib/runtime/src/config/environment_names.rs
lib/runtime/src/config/environment_names.rs
+5
-0
lib/runtime/src/transports/nats.rs
lib/runtime/src/transports/nats.rs
+0
-6
tests/conftest.py
tests/conftest.py
+199
-103
tests/router/common.py
tests/router/common.py
+16
-8
tests/router/test_router_e2e_with_mockers.py
tests/router/test_router_e2e_with_mockers.py
+77
-20
tests/router/test_router_e2e_with_sglang.py
tests/router/test_router_e2e_with_sglang.py
+17
-5
tests/router/test_router_e2e_with_trtllm.py
tests/router/test_router_e2e_with_trtllm.py
+16
-4
tests/router/test_router_e2e_with_vllm.py
tests/router/test_router_e2e_with_vllm.py
+16
-5
tests/utils/managed_process.py
tests/utils/managed_process.py
+19
-6
No files found.
lib/llm/src/discovery.rs
View file @
584020f4
...
@@ -2,7 +2,10 @@
...
@@ -2,7 +2,10 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: Apache-2.0
mod
model_manager
;
mod
model_manager
;
pub
use
model_manager
::{
ModelManager
,
ModelManagerError
,
RuntimeConfigsWithNotify
};
pub
use
model_manager
::{
ModelManager
,
ModelManagerError
};
pub
(
crate
)
mod
runtime_configs
;
pub
use
runtime_configs
::{
RuntimeConfigs
,
RuntimeConfigsSubscriber
};
mod
watcher
;
mod
watcher
;
pub
use
watcher
::{
ModelUpdate
,
ModelWatcher
};
pub
use
watcher
::{
ModelUpdate
,
ModelWatcher
};
...
...
lib/llm/src/discovery/model_manager.rs
View file @
584020f4
...
@@ -8,13 +8,14 @@ use std::{
...
@@ -8,13 +8,14 @@ use std::{
use
dashmap
::{
DashMap
,
mapref
::
entry
::
Entry
};
use
dashmap
::{
DashMap
,
mapref
::
entry
::
Entry
};
use
parking_lot
::{
Mutex
,
RwLock
};
use
parking_lot
::{
Mutex
,
RwLock
};
use
tokio
::
sync
::
{
Notify
,
oneshot
}
;
use
tokio
::
sync
::
oneshot
;
use
crate
::
discovery
::
KvWorkerMonitor
;
use
crate
::
discovery
::
KvWorkerMonitor
;
use
crate
::
discovery
::
runtime_configs
::
RuntimeConfigs
;
use
dynamo_runtime
::{
use
dynamo_runtime
::{
component
::{
Client
,
Endpoint
,
build_transport_type
},
component
::{
Client
,
Endpoint
,
build_transport_type
},
discovery
::
{
DiscoveryQuery
,
DiscoverySpec
,
watch_and_extract_field
}
,
discovery
::
DiscoverySpec
,
prelude
::
DistributedRuntimeProvider
,
prelude
::
DistributedRuntimeProvider
,
protocols
::
EndpointId
,
protocols
::
EndpointId
,
};
};
...
@@ -24,7 +25,7 @@ use crate::{
...
@@ -24,7 +25,7 @@ use crate::{
KvRouter
,
KvRouterConfig
,
protocols
::
WorkerId
,
router_endpoint_id
,
KvRouter
,
KvRouterConfig
,
protocols
::
WorkerId
,
router_endpoint_id
,
scheduler
::
DefaultWorkerSelector
,
scheduler
::
DefaultWorkerSelector
,
},
},
local_model
::
runtime_config
::
{
DisaggregatedEndpoint
,
ModelRuntimeConfig
},
local_model
::
runtime_config
::
DisaggregatedEndpoint
,
model_card
::
ModelDeploymentCard
,
model_card
::
ModelDeploymentCard
,
model_type
::
ModelType
,
model_type
::
ModelType
,
types
::{
types
::{
...
@@ -81,14 +82,8 @@ pub struct ModelManager {
...
@@ -81,14 +82,8 @@ pub struct ModelManager {
/// Runtime configs per endpoint using DashMap for lock-free access.
/// Runtime configs per endpoint using DashMap for lock-free access.
/// Outer DashMap: keyed by EndpointId
/// Outer DashMap: keyed by EndpointId
/// Inner RuntimeConfigsWithNotify: shared with KvScheduler
/// Inner RuntimeConfigs: shared with KvScheduler
runtime_configs
:
DashMap
<
EndpointId
,
Arc
<
RuntimeConfigsWithNotify
>>
,
runtime_configs
:
DashMap
<
EndpointId
,
Arc
<
RuntimeConfigs
>>
,
}
/// Runtime configs for an endpoint with a notify for change notifications.
pub
struct
RuntimeConfigsWithNotify
{
pub
configs
:
DashMap
<
WorkerId
,
Option
<
ModelRuntimeConfig
>>
,
pub
notify
:
Notify
,
}
}
impl
Default
for
ModelManager
{
impl
Default
for
ModelManager
{
...
@@ -621,12 +616,12 @@ impl ModelManager {
...
@@ -621,12 +616,12 @@ impl ModelManager {
}
}
/// Get or create a runtime config watcher for an endpoint.
/// Get or create a runtime config watcher for an endpoint.
/// Spawns a background task to watch
DiscoveryQuery::EndpointModel
s.
/// Spawns a background task to watch
for worker config change
s.
/// Returns a shared RuntimeConfigs
WithNotify
that KvScheduler can use directly.
/// Returns a shared RuntimeConfigs that KvScheduler can use directly.
pub
async
fn
get_or_create_runtime_config_watcher
(
pub
async
fn
get_or_create_runtime_config_watcher
(
&
self
,
&
self
,
endpoint
:
&
Endpoint
,
endpoint
:
&
Endpoint
,
)
->
anyhow
::
Result
<
Arc
<
RuntimeConfigs
WithNotify
>>
{
)
->
anyhow
::
Result
<
Arc
<
RuntimeConfigs
>>
{
let
endpoint_id
=
endpoint
.id
();
let
endpoint_id
=
endpoint
.id
();
// Fast path: return existing if present
// Fast path: return existing if present
...
@@ -635,10 +630,7 @@ impl ModelManager {
...
@@ -635,10 +630,7 @@ impl ModelManager {
}
}
// Atomic get-or-insert to avoid TOCTOU race
// Atomic get-or-insert to avoid TOCTOU race
let
inner
=
Arc
::
new
(
RuntimeConfigsWithNotify
{
let
inner
=
Arc
::
new
(
RuntimeConfigs
::
new
());
configs
:
DashMap
::
new
(),
notify
:
Notify
::
new
(),
});
let
(
result
,
is_new
)
=
match
self
.runtime_configs
.entry
(
endpoint_id
)
{
let
(
result
,
is_new
)
=
match
self
.runtime_configs
.entry
(
endpoint_id
)
{
Entry
::
Occupied
(
e
)
=>
(
e
.get
()
.clone
(),
false
),
Entry
::
Occupied
(
e
)
=>
(
e
.get
()
.clone
(),
false
),
Entry
::
Vacant
(
e
)
=>
{
Entry
::
Vacant
(
e
)
=>
{
...
@@ -649,8 +641,7 @@ impl ModelManager {
...
@@ -649,8 +641,7 @@ impl ModelManager {
// Only spawn watcher if we were the one who inserted
// Only spawn watcher if we were the one who inserted
if
is_new
{
if
is_new
{
self
.spawn_runtime_config_watcher
(
endpoint
,
result
.clone
())
result
.start_watcher
(
endpoint
)
.await
?
;
.await
?
;
}
}
Ok
(
result
)
Ok
(
result
)
...
@@ -668,126 +659,6 @@ impl ModelManager {
...
@@ -668,126 +659,6 @@ impl ModelManager {
config_ref
.as_ref
()
?
.disaggregated_endpoint
.clone
()
config_ref
.as_ref
()
?
.disaggregated_endpoint
.clone
()
}
}
/// Spawn background task to watch runtime configs via discovery.
/// Blocks until at least one worker with a runtime config is available.
async
fn
spawn_runtime_config_watcher
(
&
self
,
endpoint
:
&
Endpoint
,
inner
:
Arc
<
RuntimeConfigsWithNotify
>
,
)
->
anyhow
::
Result
<
()
>
{
let
component
=
endpoint
.component
();
let
cancellation_token
=
component
.drt
()
.primary_token
();
// Set up discovery watch for EndpointModels
let
discovery
=
component
.drt
()
.discovery
();
let
endpoint_id
=
endpoint
.id
();
let
discovery_key
=
DiscoveryQuery
::
EndpointModels
{
namespace
:
endpoint_id
.namespace
.clone
(),
component
:
endpoint_id
.component
.clone
(),
endpoint
:
endpoint_id
.name
.clone
(),
};
let
discovery_stream
=
discovery
.list_and_watch
(
discovery_key
.clone
(),
Some
(
cancellation_token
.clone
()))
.await
?
;
// Extract runtime_config from ModelDeploymentCard
let
mut
runtime_configs_rx
=
watch_and_extract_field
(
discovery_stream
,
|
card
:
ModelDeploymentCard
|
{
card
.runtime_config
});
// Also watch instance IDs
let
client
=
endpoint
.client
()
.await
?
;
let
mut
instance_ids_rx
=
client
.instance_avail_watcher
();
// Wait for at least one worker with runtime config before proceeding.
// This ensures the DashMap is populated before KvScheduler starts.
tracing
::
info!
(
"ModelManager: Waiting for at least one worker with runtime config..."
);
runtime_configs_rx
.changed
()
.await
.map_err
(|
_
|
anyhow
::
anyhow!
(
"runtime configs watch sender shutdown while waiting"
))
?
;
// Populate initial state
{
let
instance_ids
=
instance_ids_rx
.borrow
();
let
configs
=
runtime_configs_rx
.borrow
();
for
worker_id
in
instance_ids
.iter
()
{
let
config
=
configs
.get
(
worker_id
)
.cloned
();
inner
.configs
.insert
(
*
worker_id
,
config
);
}
tracing
::
info!
(
"ModelManager: Found {} workers, proceeding"
,
inner
.configs
.len
()
);
}
// Spawn background task to update configs for future changes
let
cancel_token
=
cancellation_token
.clone
();
tokio
::
spawn
(
async
move
{
tracing
::
trace!
(
"ModelManager runtime config watcher started"
);
loop
{
// Wait for either instances or configs to change
tokio
::
select!
{
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
trace!
(
"ModelManager runtime config watcher shutting down"
);
break
;
}
result
=
instance_ids_rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
warn!
(
"instance IDs watch sender shutdown in ModelManager"
);
break
;
}
}
result
=
runtime_configs_rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
warn!
(
"runtime configs watch sender shutdown in ModelManager"
);
break
;
}
}
}
// Get the latest values from both channels
let
new_instance_ids
=
instance_ids_rx
.borrow_and_update
()
.clone
();
let
new_configs
=
runtime_configs_rx
.borrow_and_update
()
.clone
();
// Update the DashMap
// First, remove workers that no longer exist
let
current_workers
:
HashSet
<
WorkerId
>
=
inner
.configs
.iter
()
.map
(|
r
|
*
r
.key
())
.collect
();
let
new_workers
:
HashSet
<
WorkerId
>
=
new_instance_ids
.iter
()
.copied
()
.collect
();
for
removed_worker
in
current_workers
.difference
(
&
new_workers
)
{
inner
.configs
.remove
(
removed_worker
);
}
// Then, add/update workers
for
worker_id
in
&
new_instance_ids
{
let
config
=
new_configs
.get
(
worker_id
)
.cloned
();
if
config
.is_some
()
{
let
prev_config
=
inner
.configs
.get
(
worker_id
);
if
prev_config
.as_ref
()
.map
(|
r
|
r
.value
())
!=
Some
(
&
config
)
{
tracing
::
info!
(
"ModelManager: Runtime config found for worker_id: {worker_id}"
);
}
}
inner
.configs
.insert
(
*
worker_id
,
config
);
}
// Notify waiters that configs have changed
inner
.notify
.notify_waiters
();
tracing
::
trace!
(
"ModelManager: Updated runtime_configs with {} workers"
,
inner
.configs
.len
()
);
}
tracing
::
trace!
(
"ModelManager runtime config watcher shutting down"
);
});
Ok
(())
}
/// Lists all models that have worker monitors (and thus busy thresholds) configured.
/// Lists all models that have worker monitors (and thus busy thresholds) configured.
///
///
/// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples.
/// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples.
...
...
lib/llm/src/discovery/runtime_configs.rs
0 → 100644
View file @
584020f4
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use
std
::
collections
::{
HashMap
,
HashSet
};
use
std
::
sync
::
Arc
;
use
dashmap
::
DashMap
;
use
tokio
::
sync
::
watch
;
use
dynamo_runtime
::
component
::
Endpoint
;
use
dynamo_runtime
::
discovery
::{
DiscoveryQuery
,
watch_and_extract_field
};
use
dynamo_runtime
::
prelude
::
DistributedRuntimeProvider
;
use
crate
::
kv_router
::
protocols
::
WorkerId
;
use
crate
::
local_model
::
runtime_config
::
ModelRuntimeConfig
;
use
crate
::
model_card
::
ModelDeploymentCard
;
/// Runtime configs for an endpoint with watch-based change notifications.
/// Call `subscribe()` to get a subscriber with its own watch receiver.
pub
struct
RuntimeConfigs
{
pub
configs
:
Arc
<
DashMap
<
WorkerId
,
Option
<
ModelRuntimeConfig
>>>
,
change_tx
:
watch
::
Sender
<
u64
>
,
}
impl
RuntimeConfigs
{
pub
(
crate
)
fn
new
()
->
Self
{
let
(
change_tx
,
_
)
=
watch
::
channel
(
0u64
);
Self
{
configs
:
Arc
::
new
(
DashMap
::
new
()),
change_tx
,
}
}
/// Create a subscriber that can wait for config changes.
/// Each subscriber has its own watch receiver, so notifications are not lost.
pub
fn
subscribe
(
&
self
)
->
RuntimeConfigsSubscriber
{
RuntimeConfigsSubscriber
{
configs
:
self
.configs
.clone
(),
change_rx
:
self
.change_tx
.subscribe
(),
}
}
/// Notify all subscribers of a change (internal use only).
fn
notify_change
(
&
self
)
{
// Increment counter to notify subscribers
self
.change_tx
.send_modify
(|
v
|
*
v
=
v
.wrapping_add
(
1
));
}
/// Returns the number of workers in the configs.
pub
fn
num_workers
(
&
self
)
->
usize
{
self
.configs
.len
()
}
/// Update configs with new worker instances and their configs.
/// Notifies subscribers if a config with Some value is added or a worker is removed.
pub
(
crate
)
fn
update
(
&
self
,
new_instance_ids
:
&
[
WorkerId
],
new_configs
:
&
HashMap
<
WorkerId
,
ModelRuntimeConfig
>
,
)
{
// First, remove workers that no longer exist
let
current_workers
:
HashSet
<
WorkerId
>
=
self
.configs
.iter
()
.map
(|
r
|
*
r
.key
())
.collect
();
let
new_workers
:
HashSet
<
WorkerId
>
=
new_instance_ids
.iter
()
.copied
()
.collect
();
let
mut
worker_removed
=
false
;
for
removed_worker
in
current_workers
.difference
(
&
new_workers
)
{
self
.configs
.remove
(
removed_worker
);
worker_removed
=
true
;
}
// Then, add/update workers
// Track if any config became Some (for notify)
let
mut
config_added
=
false
;
for
worker_id
in
new_instance_ids
{
let
config
=
new_configs
.get
(
worker_id
)
.cloned
();
if
config
.is_some
()
{
let
prev_config
=
self
.configs
.get
(
worker_id
);
let
was_none
=
prev_config
.as_ref
()
.map
(|
r
|
r
.value
()
.is_none
())
.unwrap_or
(
true
);
if
was_none
{
tracing
::
info!
(
"RuntimeConfigs: config found for worker_id: {worker_id}"
);
config_added
=
true
;
}
}
self
.configs
.insert
(
*
worker_id
,
config
);
}
// Notify when a config with Some value is added OR a worker is removed
if
config_added
||
worker_removed
{
self
.notify_change
();
}
}
/// Spawn background task to watch runtime configs via discovery.
/// Does not block - consumers should use `subscribe().wait_for_some()` if they need workers.
pub
(
crate
)
async
fn
start_watcher
(
self
:
&
Arc
<
Self
>
,
endpoint
:
&
Endpoint
)
->
anyhow
::
Result
<
()
>
{
let
component
=
endpoint
.component
();
let
cancellation_token
=
component
.drt
()
.primary_token
();
// Set up discovery watch for EndpointModels
let
discovery
=
component
.drt
()
.discovery
();
let
endpoint_id
=
endpoint
.id
();
let
discovery_key
=
DiscoveryQuery
::
EndpointModels
{
namespace
:
endpoint_id
.namespace
.clone
(),
component
:
endpoint_id
.component
.clone
(),
endpoint
:
endpoint_id
.name
.clone
(),
};
let
discovery_stream
=
discovery
.list_and_watch
(
discovery_key
.clone
(),
Some
(
cancellation_token
.clone
()))
.await
?
;
// Extract runtime_config from ModelDeploymentCard
let
mut
runtime_configs_rx
=
watch_and_extract_field
(
discovery_stream
,
|
card
:
ModelDeploymentCard
|
{
card
.runtime_config
});
// Also watch instance IDs
let
client
=
endpoint
.client
()
.await
?
;
let
mut
instance_ids_rx
=
client
.instance_avail_watcher
();
// Spawn background task to watch for config changes
// Note: We don't block here - consumers should wait on notify for configs they need
let
inner
=
self
.clone
();
let
cancel_token
=
cancellation_token
.clone
();
tokio
::
spawn
(
async
move
{
tracing
::
trace!
(
"RuntimeConfigs watcher started"
);
loop
{
// Wait for either instances or configs to change
tokio
::
select!
{
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
trace!
(
"RuntimeConfigs watcher shutting down"
);
break
;
}
result
=
instance_ids_rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
warn!
(
"instance IDs watch sender shutdown"
);
break
;
}
}
result
=
runtime_configs_rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
warn!
(
"runtime configs watch sender shutdown"
);
break
;
}
}
}
// Get the latest values from both channels
let
new_instance_ids
=
instance_ids_rx
.borrow_and_update
()
.clone
();
let
new_configs
=
runtime_configs_rx
.borrow_and_update
()
.clone
();
inner
.update
(
&
new_instance_ids
,
&
new_configs
);
tracing
::
trace!
(
"RuntimeConfigs: Updated with {} workers"
,
inner
.configs
.len
()
);
}
tracing
::
trace!
(
"RuntimeConfigs watcher stopped"
);
});
Ok
(())
}
}
/// A subscriber to runtime config changes.
/// Each subscriber has its own watch receiver, ensuring no notifications are lost.
pub
struct
RuntimeConfigsSubscriber
{
pub
configs
:
Arc
<
DashMap
<
WorkerId
,
Option
<
ModelRuntimeConfig
>>>
,
pub
change_rx
:
watch
::
Receiver
<
u64
>
,
}
impl
RuntimeConfigsSubscriber
{
/// Wait until at least one worker has a Some config.
/// Returns the list of worker IDs that have configs.
/// This is race-safe: checks the DashMap first, only waits if empty.
/// Returns empty vec if the sender is dropped (shutdown).
pub
async
fn
wait_for_some
(
&
mut
self
)
->
Vec
<
WorkerId
>
{
loop
{
let
ready
:
Vec
<
WorkerId
>
=
self
.configs
.iter
()
.filter
(|
r
|
r
.value
()
.is_some
())
.map
(|
r
|
*
r
.key
())
.collect
();
if
!
ready
.is_empty
()
{
return
ready
;
}
// If sender dropped (shutdown), return empty rather than loop forever
if
self
.change_rx
.changed
()
.await
.is_err
()
{
tracing
::
warn!
(
"RuntimeConfigsSubscriber: sender dropped during wait_for_some"
);
return
vec!
[];
}
}
}
}
lib/llm/src/hub.rs
View file @
584020f4
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
use
std
::
env
;
use
std
::
env
;
use
std
::
path
::{
Path
,
PathBuf
};
use
std
::
path
::{
Path
,
PathBuf
};
use
hf_hub
::
Cache
;
use
modelexpress_client
::{
use
modelexpress_client
::{
Client
as
MxClient
,
ClientConfig
as
MxClientConfig
,
ModelProvider
as
MxModelProvider
,
Client
as
MxClient
,
ClientConfig
as
MxClientConfig
,
ModelProvider
as
MxModelProvider
,
};
};
...
@@ -11,14 +12,77 @@ use modelexpress_common::download as mx;
...
@@ -11,14 +12,77 @@ use modelexpress_common::download as mx;
use
dynamo_runtime
::
config
::
environment_names
::
model
as
env_model
;
use
dynamo_runtime
::
config
::
environment_names
::
model
as
env_model
;
/// Check if a model is already cached in the HuggingFace hub cache directory.
/// Returns the path to the cached model directory if found, None otherwise.
///
/// Uses hf-hub's Cache API to check for cached files. For tokenizer-only downloads
/// (ignore_weights=true), we check for config.json and tokenizer files.
/// For full downloads, we also require weight files to be present.
fn
get_cached_model_path
(
model_name
:
&
str
,
ignore_weights
:
bool
)
->
Option
<
PathBuf
>
{
let
cache
=
Cache
::
new
(
get_model_express_cache_dir
());
let
repo
=
cache
.model
(
model_name
.to_string
());
// Check for required config file
let
config_path
=
repo
.get
(
"config.json"
)
?
;
// Check for tokenizer files (at least one must exist)
let
has_tokenizer
=
repo
.get
(
"tokenizer.json"
)
.is_some
()
||
repo
.get
(
"tokenizer_config.json"
)
.is_some
();
if
!
has_tokenizer
{
return
None
;
}
// For full downloads, check for weight files
if
!
ignore_weights
{
// Check common weight file patterns - at least one must exist
let
has_weights
=
repo
.get
(
"model.safetensors"
)
.is_some
()
||
repo
.get
(
"pytorch_model.bin"
)
.is_some
()
||
repo
.get
(
"model.safetensors.index.json"
)
.is_some
()
||
repo
.get
(
"pytorch_model.bin.index.json"
)
.is_some
();
if
!
has_weights
{
return
None
;
}
}
// Return the parent directory (snapshot dir) containing the model files
let
snapshot_path
=
config_path
.parent
()
?
.to_path_buf
();
tracing
::
info!
(
"Found cached model '{model_name}' at {snapshot_path:?}, skipping download"
);
Some
(
snapshot_path
)
}
/// Check if offline mode is enabled via HF_HUB_OFFLINE environment variable.
fn
is_offline_mode
()
->
bool
{
env
::
var
(
env_model
::
huggingface
::
HF_HUB_OFFLINE
)
.map
(|
v
|
v
==
"1"
||
v
.to_lowercase
()
==
"true"
)
.unwrap_or
(
false
)
}
/// Download a model using ModelExpress client. The client first requests for the model
/// Download a model using ModelExpress client. The client first requests for the model
/// from the server and fallbacks to direct download in case of server failure.
/// from the server and fallbacks to direct download in case of server failure.
/// If ignore_weights is true, model weight files will be skipped
/// If ignore_weights is true, model weight files will be skipped
/// Returns the path to the model files
/// Returns the path to the model files
///
/// If HF_HUB_OFFLINE=1 is set and the model is already cached, returns the cached
/// path without making any API calls to HuggingFace.
pub
async
fn
from_hf
(
name
:
impl
AsRef
<
Path
>
,
ignore_weights
:
bool
)
->
anyhow
::
Result
<
PathBuf
>
{
pub
async
fn
from_hf
(
name
:
impl
AsRef
<
Path
>
,
ignore_weights
:
bool
)
->
anyhow
::
Result
<
PathBuf
>
{
let
name
=
name
.as_ref
();
let
name
=
name
.as_ref
();
let
model_name
=
name
.display
()
.to_string
();
let
model_name
=
name
.display
()
.to_string
();
// In offline mode, check cache first and return immediately if found
if
is_offline_mode
()
{
if
let
Some
(
cached_path
)
=
get_cached_model_path
(
&
model_name
,
ignore_weights
)
{
tracing
::
info!
(
"Offline mode: using cached model '{model_name}' without API validation"
);
return
Ok
(
cached_path
);
}
tracing
::
warn!
(
"Offline mode enabled but model '{model_name}' not found in cache, attempting download anyway"
);
}
let
mut
config
:
MxClientConfig
=
MxClientConfig
::
default
();
let
mut
config
:
MxClientConfig
=
MxClientConfig
::
default
();
if
let
Ok
(
endpoint
)
=
env
::
var
(
env_model
::
model_express
::
MODEL_EXPRESS_URL
)
{
if
let
Ok
(
endpoint
)
=
env
::
var
(
env_model
::
model_express
::
MODEL_EXPRESS_URL
)
{
config
=
config
.with_endpoint
(
endpoint
);
config
=
config
.with_endpoint
(
endpoint
);
...
...
lib/llm/src/kv_router.rs
View file @
584020f4
...
@@ -9,7 +9,7 @@ use anyhow::Result;
...
@@ -9,7 +9,7 @@ use anyhow::Result;
use
derive_builder
::
Builder
;
use
derive_builder
::
Builder
;
use
dynamo_runtime
::{
use
dynamo_runtime
::{
component
::{
Client
,
Endpoint
},
component
::{
Client
,
Endpoint
},
discovery
::{
DiscoveryQuery
,
EventTransportKind
,
watch_and_extract_field
},
discovery
::{
DiscoveryQuery
,
EventTransportKind
},
pipeline
::{
pipeline
::{
AsyncEngine
,
AsyncEngineContextProvider
,
Error
,
ManyOut
,
PushRouter
,
ResponseStream
,
AsyncEngine
,
AsyncEngineContextProvider
,
Error
,
ManyOut
,
PushRouter
,
ResponseStream
,
SingleIn
,
async_trait
,
SingleIn
,
async_trait
,
...
@@ -41,7 +41,7 @@ pub use prefill_router::PrefillRouter;
...
@@ -41,7 +41,7 @@ pub use prefill_router::PrefillRouter;
use
worker_query
::
WorkerQueryClient
;
use
worker_query
::
WorkerQueryClient
;
use
crate
::{
use
crate
::{
discovery
::
RuntimeConfigs
WithNotify
,
discovery
::
RuntimeConfigs
,
kv_router
::{
kv_router
::{
approx
::
PruneConfig
,
approx
::
PruneConfig
,
indexer
::{
KvIndexer
,
KvIndexerInterface
,
KvRouterError
},
indexer
::{
KvIndexer
,
KvIndexerInterface
,
KvRouterError
},
...
@@ -55,7 +55,6 @@ use crate::{
...
@@ -55,7 +55,6 @@ use crate::{
subscriber
::{
start_kv_router_background
,
start_kv_router_background_event_plane
},
subscriber
::{
start_kv_router_background
,
start_kv_router_background_event_plane
},
},
},
local_model
::
runtime_config
::
ModelRuntimeConfig
,
local_model
::
runtime_config
::
ModelRuntimeConfig
,
model_card
::
ModelDeploymentCard
,
preprocessor
::
PreprocessedRequest
,
preprocessor
::
PreprocessedRequest
,
protocols
::
common
::
llm_backend
::
LLMEngineOutput
,
protocols
::
common
::
llm_backend
::
LLMEngineOutput
,
protocols
::
common
::
timing
::
RequestPhase
,
protocols
::
common
::
timing
::
RequestPhase
,
...
@@ -332,7 +331,7 @@ impl KvRouter {
...
@@ -332,7 +331,7 @@ impl KvRouter {
pub
async
fn
new
(
pub
async
fn
new
(
endpoint
:
Endpoint
,
endpoint
:
Endpoint
,
client
:
Client
,
client
:
Client
,
workers_with_configs
:
Arc
<
RuntimeConfigs
WithNotify
>
,
workers_with_configs
:
Arc
<
RuntimeConfigs
>
,
block_size
:
u32
,
block_size
:
u32
,
selector
:
Option
<
Box
<
dyn
WorkerSelector
+
Send
+
Sync
>>
,
selector
:
Option
<
Box
<
dyn
WorkerSelector
+
Send
+
Sync
>>
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
kv_router_config
:
Option
<
KvRouterConfig
>
,
...
@@ -342,23 +341,6 @@ impl KvRouter {
...
@@ -342,23 +341,6 @@ impl KvRouter {
let
component
=
endpoint
.component
();
let
component
=
endpoint
.component
();
let
cancellation_token
=
component
.drt
()
.primary_token
();
let
cancellation_token
=
component
.drt
()
.primary_token
();
// Watch for runtime config updates via discovery interface
// (still needed for WorkerQueryClient and background tasks)
let
discovery
=
component
.drt
()
.discovery
();
let
endpoint_id
=
endpoint
.id
();
let
discovery_key
=
DiscoveryQuery
::
EndpointModels
{
namespace
:
endpoint_id
.namespace
.clone
(),
component
:
endpoint_id
.component
.clone
(),
endpoint
:
endpoint_id
.name
.clone
(),
};
let
discovery_stream
=
discovery
.list_and_watch
(
discovery_key
.clone
(),
Some
(
cancellation_token
.clone
()))
.await
?
;
let
runtime_configs_rx
=
watch_and_extract_field
(
discovery_stream
,
|
card
:
ModelDeploymentCard
|
{
card
.runtime_config
});
let
indexer
=
if
kv_router_config
.overlap_score_weight
==
0.0
{
let
indexer
=
if
kv_router_config
.overlap_score_weight
==
0.0
{
// When overlap_score_weight is zero, we don't need to track prefixes
// When overlap_score_weight is zero, we don't need to track prefixes
Indexer
::
None
Indexer
::
None
...
@@ -385,6 +367,9 @@ impl KvRouter {
...
@@ -385,6 +367,9 @@ impl KvRouter {
))
))
};
};
// Wait for at least one worker with a known runtime config before starting scheduler
workers_with_configs
.subscribe
()
.wait_for_some
()
.await
;
let
scheduler
=
KvScheduler
::
start
(
let
scheduler
=
KvScheduler
::
start
(
component
.clone
(),
component
.clone
(),
block_size
,
block_size
,
...
@@ -397,30 +382,27 @@ impl KvRouter {
...
@@ -397,30 +382,27 @@ impl KvRouter {
// Initialize worker query client using namespace abstraction
// Initialize worker query client using namespace abstraction
// (created before background task so we can use it for startup recovery)
// (created before background task so we can use it for startup recovery)
let
worker_query_client
=
// Uses a subscriber from workers_with_configs
worker_query
::
WorkerQueryClient
::
new
(
component
.clone
(),
runtime_configs_rx
.clone
());
let
worker_query_client
=
worker_query
::
WorkerQueryClient
::
new
(
component
.clone
(),
workers_with_configs
.subscribe
(),
);
tracing
::
info!
(
"Worker query client initialized"
);
tracing
::
info!
(
"Worker query client initialized"
);
// Start KV event subscriber background process (only when use_kv_events is enabled)
// Start KV event subscriber background process (only when use_kv_events is enabled)
// model_manager.get_or_create_runtime_config_watcher() guarantees at least one worker exists.
if
kv_router_config
.use_kv_events
if
kv_router_config
.use_kv_events
&&
let
Indexer
::
KvIndexer
(
ref
kv_indexer
)
=
indexer
&&
let
Indexer
::
KvIndexer
(
ref
kv_indexer
)
=
indexer
{
{
// model_manager guarantees workers_with_configs is populated
// Wait for at least one worker before starting the subscriber
while
workers_with_configs
.configs
.is_empty
()
{
tracing
::
info!
(
"KV router waiting for at least one worker..."
);
workers_with_configs
.notify
.notified
()
.await
;
}
let
count
=
workers_with_configs
.configs
.len
();
let
all_local_indexer
=
workers_with_configs
let
all_local_indexer
=
workers_with_configs
.configs
.configs
.iter
()
.iter
()
.filter_map
(|
r
|
r
.value
()
.as_ref
()
.map
(|
c
|
c
.enable_local_indexer
))
.filter_map
(|
r
|
r
.value
()
.as_ref
()
.map
(|
c
|
c
.enable_local_indexer
))
.all
(|
b
|
b
);
.all
(|
b
|
b
);
tracing
::
info!
(
"Found {count} worker(s), starting KV event subscriber"
);
tracing
::
info!
(
"Found {} worker(s), starting KV event subscriber"
,
workers_with_configs
.num_workers
()
);
let
transport_kind
=
EventTransportKind
::
from_env_or_default
();
let
transport_kind
=
EventTransportKind
::
from_env_or_default
();
...
@@ -436,7 +418,8 @@ impl KvRouter {
...
@@ -436,7 +418,8 @@ impl KvRouter {
}
}
}
else
{
}
else
{
tracing
::
info!
(
tracing
::
info!
(
"All {count} workers have local_indexer enabled, using NATS Core subscription"
"All {} workers have local_indexer enabled, using NATS Core subscription"
,
workers_with_configs
.num_workers
()
);
);
}
}
...
@@ -447,7 +430,7 @@ impl KvRouter {
...
@@ -447,7 +430,7 @@ impl KvRouter {
cancellation_token
.clone
(),
cancellation_token
.clone
(),
worker_query
::
WorkerQueryClient
::
new
(
worker_query
::
WorkerQueryClient
::
new
(
component
.clone
(),
component
.clone
(),
runtime_configs_rx
.clon
e
(),
workers_with_configs
.subscrib
e
(),
),
),
transport_kind
,
transport_kind
,
)
)
...
...
lib/llm/src/kv_router/publisher.rs
View file @
584020f4
...
@@ -921,7 +921,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
...
@@ -921,7 +921,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
// -------------------------------------------------------------------------
// -------------------------------------------------------------------------
/// Metrics data passed through the channel for NATS publishing
/// Metrics data passed through the channel for NATS publishing
#[derive(Debug,
Clone,
Default)]
#[derive(Debug,
Clone,
Default
,
PartialEq
)]
struct
WorkerMetrics
{
struct
WorkerMetrics
{
dp_rank
:
DpRank
,
dp_rank
:
DpRank
,
active_decode_blocks
:
u64
,
active_decode_blocks
:
u64
,
...
@@ -982,7 +982,7 @@ impl WorkerMetricsPublisher {
...
@@ -982,7 +982,7 @@ impl WorkerMetricsPublisher {
};
};
let
mut
rx
=
nats_rx
;
let
mut
rx
=
nats_rx
;
let
mut
last_
active_decode_blocks
:
Option
<
u64
>
=
Some
(
0
)
;
let
mut
last_
metrics
:
Option
<
WorkerMetrics
>
=
None
;
let
mut
pending_publish
:
Option
<
WorkerMetrics
>
=
None
;
let
mut
pending_publish
:
Option
<
WorkerMetrics
>
=
None
;
let
mut
publish_timer
=
let
mut
publish_timer
=
Box
::
pin
(
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_secs
(
0
)));
Box
::
pin
(
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_secs
(
0
)));
...
@@ -1001,16 +1001,13 @@ impl WorkerMetricsPublisher {
...
@@ -1001,16 +1001,13 @@ impl WorkerMetricsPublisher {
let
metrics
=
rx
.borrow_and_update
()
.clone
();
let
metrics
=
rx
.borrow_and_update
()
.clone
();
// Check if active_decode_blocks has changed
// Check if metrics have changed
let
has_changed
=
match
last_active_decode_blocks
{
let
has_changed
=
last_metrics
.as_ref
()
!=
Some
(
&
metrics
);
Some
(
last
)
=>
last
!=
metrics
.active_decode_blocks
,
None
=>
true
,
// First time, consider it changed
};
// If
load
metrics changed, schedule a publish
// If metrics changed, schedule a publish
if
has_changed
{
if
has_changed
{
pending_publish
=
Some
(
metrics
.clone
());
pending_publish
=
Some
(
metrics
.clone
());
last_
active_decode_blocks
=
Some
(
metrics
.active_decode_block
s
);
last_
metrics
=
Some
(
metric
s
);
// Start the 1ms timer
// Start the 1ms timer
publish_timer
.as_mut
()
.reset
(
publish_timer
.as_mut
()
.reset
(
...
...
lib/llm/src/kv_router/scheduler.rs
View file @
584020f4
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: Apache-2.0
use
crate
::
discovery
::
RuntimeConfigs
WithNotify
;
use
crate
::
discovery
::
RuntimeConfigs
;
use
crate
::
local_model
::
runtime_config
::
ModelRuntimeConfig
;
use
crate
::
local_model
::
runtime_config
::
ModelRuntimeConfig
;
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
dynamo_runtime
::
component
::
Component
;
use
dynamo_runtime
::
component
::
Component
;
...
@@ -98,7 +98,7 @@ impl KvScheduler {
...
@@ -98,7 +98,7 @@ impl KvScheduler {
pub
async
fn
start
(
pub
async
fn
start
(
component
:
Component
,
component
:
Component
,
block_size
:
u32
,
block_size
:
u32
,
workers_with_configs
:
Arc
<
RuntimeConfigs
WithNotify
>
,
workers_with_configs
:
Arc
<
RuntimeConfigs
>
,
selector
:
Option
<
Box
<
dyn
WorkerSelector
+
Send
+
Sync
>>
,
selector
:
Option
<
Box
<
dyn
WorkerSelector
+
Send
+
Sync
>>
,
replica_sync
:
bool
,
replica_sync
:
bool
,
router_id
:
u64
,
router_id
:
u64
,
...
@@ -106,7 +106,7 @@ impl KvScheduler {
...
@@ -106,7 +106,7 @@ impl KvScheduler {
let
selector
=
selector
.unwrap_or
(
Box
::
new
(
DefaultWorkerSelector
::
default
()));
let
selector
=
selector
.unwrap_or
(
Box
::
new
(
DefaultWorkerSelector
::
default
()));
// Get initial workers from DashMap for slot initialization.
// Get initial workers from DashMap for slot initialization.
//
ModelManager guarantees
at least one worker is present
before KvRouter::new() is called
.
//
Caller must ensure
at least one worker is present
(via wait_for_some)
.
let
initial_workers
:
HashMap
<
WorkerId
,
Option
<
ModelRuntimeConfig
>>
=
workers_with_configs
let
initial_workers
:
HashMap
<
WorkerId
,
Option
<
ModelRuntimeConfig
>>
=
workers_with_configs
.configs
.configs
.iter
()
.iter
()
...
@@ -126,9 +126,11 @@ impl KvScheduler {
...
@@ -126,9 +126,11 @@ impl KvScheduler {
);
);
// Spawn background task to sync slots with DashMap when notified of changes.
// Spawn background task to sync slots with DashMap when notified of changes.
// ModelManager's watcher updates the DashMap and notifies; we wait on
notify
here.
// ModelManager's watcher updates the DashMap and notifies; we wait on
watch receiver
here.
let
slots_monitor
=
slots
.clone
();
let
slots_monitor
=
slots
.clone
();
let
workers_monitor
=
workers_with_configs
.clone
();
let
subscriber
=
workers_with_configs
.subscribe
();
let
configs_monitor
=
subscriber
.configs
;
let
mut
change_rx
=
subscriber
.change_rx
;
let
monitor_cancel_token
=
component
.drt
()
.child_token
();
let
monitor_cancel_token
=
component
.drt
()
.child_token
();
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
tracing
::
trace!
(
"KvScheduler workers monitoring task started"
);
tracing
::
trace!
(
"KvScheduler workers monitoring task started"
);
...
@@ -141,13 +143,17 @@ impl KvScheduler {
...
@@ -141,13 +143,17 @@ impl KvScheduler {
tracing
::
trace!
(
"KvScheduler workers monitoring task shutting down"
);
tracing
::
trace!
(
"KvScheduler workers monitoring task shutting down"
);
break
;
break
;
}
}
_
=
workers_monitor
.notify
.notified
()
=>
{}
result
=
change_rx
.changed
()
=>
{
if
result
.is_err
()
{
tracing
::
warn!
(
"KvScheduler: config watch sender dropped, shutting down"
);
break
;
}
}
}
}
// Get current workers from DashMap
// Get current workers from DashMap
let
current_workers
:
HashMap
<
WorkerId
,
Option
<
ModelRuntimeConfig
>>
=
let
current_workers
:
HashMap
<
WorkerId
,
Option
<
ModelRuntimeConfig
>>
=
workers_monitor
configs_monitor
.configs
.iter
()
.iter
()
.map
(|
r
|
(
*
r
.key
(),
r
.value
()
.clone
()))
.map
(|
r
|
(
*
r
.key
(),
r
.value
()
.clone
()))
.collect
();
.collect
();
...
...
lib/llm/src/kv_router/subscriber.rs
View file @
584020f4
...
@@ -55,9 +55,9 @@ const WORKER_QUERY_INITIAL_BACKOFF_MS: u64 = 200;
...
@@ -55,9 +55,9 @@ const WORKER_QUERY_INITIAL_BACKOFF_MS: u64 = 200;
// Discovery Helpers
// Discovery Helpers
// ============================================================================
// ============================================================================
///
Wait for at least one worker instance to be discovered
.
///
Get the instance discovery stream for monitoring worker add/remove events
.
///
Returns a peekable stream of discovery events for the generate endpo
in
t
.
///
Waits for at least one instance to be discovered before return
in
g
.
async
fn
wait_for_worker_instance
(
async
fn
get_instance_discovery_stream
(
component
:
&
Component
,
component
:
&
Component
,
cancellation_token
:
&
CancellationToken
,
cancellation_token
:
&
CancellationToken
,
)
->
Result
<
std
::
pin
::
Pin
<
Box
<
dyn
futures
::
Stream
<
Item
=
Result
<
DiscoveryEvent
>>
+
Send
>>>
{
)
->
Result
<
std
::
pin
::
Pin
<
Box
<
dyn
futures
::
Stream
<
Item
=
Result
<
DiscoveryEvent
>>
+
Send
>>>
{
...
@@ -524,7 +524,7 @@ pub async fn start_kv_router_background(
...
@@ -524,7 +524,7 @@ pub async fn start_kv_router_background(
// Wait for at least one worker instance before proceeding
// Wait for at least one worker instance before proceeding
let
mut
instance_event_stream
=
let
mut
instance_event_stream
=
wait_for_worker_instance
(
&
component
,
&
cancellation_token
)
.await
?
;
get_instance_discovery_stream
(
&
component
,
&
cancellation_token
)
.await
?
;
// Watch for router deletions to clean up orphaned consumers via discovery
// Watch for router deletions to clean up orphaned consumers via discovery
let
generate_endpoint
=
component
.endpoint
(
"generate"
);
let
generate_endpoint
=
component
.endpoint
(
"generate"
);
...
@@ -762,7 +762,7 @@ pub async fn start_kv_router_background_event_plane(
...
@@ -762,7 +762,7 @@ pub async fn start_kv_router_background_event_plane(
kv_events_tx
:
mpsc
::
Sender
<
RouterEvent
>
,
kv_events_tx
:
mpsc
::
Sender
<
RouterEvent
>
,
remove_worker_tx
:
mpsc
::
Sender
<
WorkerId
>
,
remove_worker_tx
:
mpsc
::
Sender
<
WorkerId
>
,
cancellation_token
:
CancellationToken
,
cancellation_token
:
CancellationToken
,
worker_query_client
:
WorkerQueryClient
,
mut
worker_query_client
:
WorkerQueryClient
,
transport_kind
:
EventTransportKind
,
transport_kind
:
EventTransportKind
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
// Subscribe to KV events using the selected event plane transport
// Subscribe to KV events using the selected event plane transport
...
@@ -792,44 +792,35 @@ pub async fn start_kv_router_background_event_plane(
...
@@ -792,44 +792,35 @@ pub async fn start_kv_router_background_event_plane(
}
}
}
}
// Wait for at least one worker instance before proceeding
// Wait for at least one worker with a known runtime config before proceeding.
let
mut
instance_event_stream
=
// This ensures we have actual config data (including enable_local_indexer) available.
wait_for_worker_instance
(
&
component
,
&
cancellation_token
)
.await
?
;
tracing
::
info!
(
"KV subscriber waiting for at least one worker with runtime config..."
);
let
ready_workers
=
worker_query_client
.wait_for_ready
()
.await
;
// Drain and process all existing workers before spawning the background loop.
tracing
::
info!
(
// list_and_watch returns existing instances first, so we poll with a short timeout
"KV subscriber found {} worker(s) with runtime config, proceeding"
,
// to process all initial workers synchronously before the router becomes "ready".
ready_workers
.len
()
loop
{
);
// Use a short timeout to detect when initial discovery events are exhausted
let
poll_result
=
// Recover initial state from all ready workers
tokio
::
time
::
timeout
(
Duration
::
from_millis
(
100
),
instance_event_stream
.next
())
.await
;
for
worker_id
in
&
ready_workers
{
if
worker_query_client
.has_local_indexer
(
*
worker_id
)
{
match
poll_result
{
match
recover_from_worker
(
&
worker_query_client
,
*
worker_id
,
None
,
None
,
&
kv_events_tx
)
Ok
(
Some
(
Ok
(
event
)))
=>
{
.await
handle_worker_discovery
(
{
event
,
Ok
(
count
)
=>
{
&
worker_query_client
,
tracing
::
info!
(
"Successfully recovered {count} events from worker {worker_id}"
);
&
kv_events_tx
,
}
&
remove_worker_tx
,
Err
(
e
)
=>
{
)
tracing
::
warn!
(
"Failed to recover from worker {worker_id}: {e}"
);
.await
;
}
}
Ok
(
Some
(
Err
(
e
)))
=>
{
tracing
::
warn!
(
"Error receiving discovery event during initial sync: {e}"
);
}
Ok
(
None
)
=>
{
// Stream ended
tracing
::
warn!
(
"Discovery stream ended during initial sync"
);
break
;
}
Err
(
_
)
=>
{
// Timeout - no more initial events
tracing
::
debug!
(
"Initial worker discovery sync complete"
);
break
;
}
}
}
}
}
}
// Get instance discovery stream for ongoing monitoring of worker add/remove events
let
mut
instance_event_stream
=
get_instance_discovery_stream
(
&
component
,
&
cancellation_token
)
.await
?
;
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
// Track last received event ID per worker for gap detection
// Track last received event ID per worker for gap detection
let
mut
last_event_ids
:
HashMap
<
WorkerId
,
u64
>
=
HashMap
::
new
();
let
mut
last_event_ids
:
HashMap
<
WorkerId
,
u64
>
=
HashMap
::
new
();
...
...
lib/llm/src/kv_router/worker_query.rs
View file @
584020f4
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: Apache-2.0
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
anyhow
::{
Context
,
Result
};
use
anyhow
::{
Context
,
Result
};
...
@@ -11,46 +10,49 @@ use dynamo_runtime::pipeline::{
...
@@ -11,46 +10,49 @@ use dynamo_runtime::pipeline::{
SingleIn
,
async_trait
,
network
::
Ingress
,
SingleIn
,
async_trait
,
network
::
Ingress
,
};
};
use
dynamo_runtime
::
protocols
::
maybe_error
::
MaybeError
;
use
dynamo_runtime
::
protocols
::
maybe_error
::
MaybeError
;
use
tokio
::
sync
::
{
OnceCell
,
watch
}
;
use
tokio
::
sync
::
OnceCell
;
use
tokio_stream
::
StreamExt
;
use
tokio_stream
::
StreamExt
;
use
crate
::
discovery
::
RuntimeConfigsSubscriber
;
use
crate
::
kv_router
::
WORKER_KV_INDEXER_QUERY_ENDPOINT
;
use
crate
::
kv_router
::
WORKER_KV_INDEXER_QUERY_ENDPOINT
;
use
crate
::
kv_router
::
indexer
::{
LocalKvIndexer
,
WorkerKvQueryRequest
,
WorkerKvQueryResponse
};
use
crate
::
kv_router
::
indexer
::{
LocalKvIndexer
,
WorkerKvQueryRequest
,
WorkerKvQueryResponse
};
use
crate
::
kv_router
::
protocols
::
WorkerId
;
use
crate
::
kv_router
::
protocols
::
WorkerId
;
use
crate
::
local_model
::
runtime_config
::
ModelRuntimeConfig
;
use
dynamo_runtime
::
stream
;
use
dynamo_runtime
::
stream
;
/// Router-side client for querying worker local KV indexers
/// Router-side client for querying worker local KV indexers
///
///
/// Performs request/reply communication with workers via request plane endpoint routing.
/// Performs request/reply communication with workers via request plane endpoint routing.
/// (Only queries workers that have `enable_local_indexer=true` in their MDC user_data)
/// (Only queries workers that have `enable_local_indexer=true` in their MDC user_data)
/// The client is spawned by KvRouter; it
watches same discovery stream as the router
.
/// The client is spawned by KvRouter; it
uses a subscriber from RuntimeConfigs
.
pub
struct
WorkerQueryClient
{
pub
struct
WorkerQueryClient
{
component
:
Component
,
component
:
Component
,
///
Watch receiver for enable_local_indexer state per worker
///
Subscriber for runtime configs (includes shared configs DashMap)
model_runtime_config_rx
:
watch
::
Receiver
<
HashMap
<
WorkerId
,
Model
RuntimeConfig
>>
,
subscriber
:
RuntimeConfig
sSubscriber
,
router
:
OnceCell
<
Arc
<
PushRouter
<
WorkerKvQueryRequest
,
WorkerKvQueryResponse
>>>
,
router
:
OnceCell
<
Arc
<
PushRouter
<
WorkerKvQueryRequest
,
WorkerKvQueryResponse
>>>
,
}
}
impl
WorkerQueryClient
{
impl
WorkerQueryClient
{
/// Create a new WorkerQueryClient with a watch receiver for local indexer states
/// Create a new WorkerQueryClient with a subscriber to runtime configs
pub
fn
new
(
pub
fn
new
(
component
:
Component
,
subscriber
:
RuntimeConfigsSubscriber
)
->
Self
{
component
:
Component
,
model_runtime_config_rx
:
watch
::
Receiver
<
HashMap
<
WorkerId
,
ModelRuntimeConfig
>>
,
)
->
Self
{
Self
{
Self
{
component
,
component
,
model_runtime_config_rx
,
subscriber
,
router
:
OnceCell
::
new
(),
router
:
OnceCell
::
new
(),
}
}
}
}
/// Wait until at least one worker has a known runtime config (Some).
/// Returns the list of worker IDs that have configs.
pub
async
fn
wait_for_ready
(
&
mut
self
)
->
Vec
<
WorkerId
>
{
self
.subscriber
.wait_for_some
()
.await
}
/// Check if a worker has local indexer enabled
/// Check if a worker has local indexer enabled
pub
fn
has_local_indexer
(
&
self
,
worker_id
:
WorkerId
)
->
bool
{
pub
fn
has_local_indexer
(
&
self
,
worker_id
:
WorkerId
)
->
bool
{
self
.
model_runtime_config_rx
self
.
subscriber
.
borrow
()
.
configs
.get
(
&
worker_id
)
.get
(
&
worker_id
)
.
map
(|
config
|
config
.enable_local_indexer
)
.
and_then
(|
entry
|
entry
.value
()
.as_ref
()
.map
(|
c
|
c
.enable_local_indexer
)
)
.unwrap_or
(
false
)
.unwrap_or
(
false
)
}
}
...
...
lib/runtime/src/config/environment_names.rs
View file @
584020f4
...
@@ -301,6 +301,10 @@ pub mod model {
...
@@ -301,6 +301,10 @@ pub mod model {
/// Hugging Face home directory
/// Hugging Face home directory
pub
const
HF_HOME
:
&
str
=
"HF_HOME"
;
pub
const
HF_HOME
:
&
str
=
"HF_HOME"
;
/// Offline mode - skip API calls when model is cached
/// Set to "1" or "true" to enable
pub
const
HF_HUB_OFFLINE
:
&
str
=
"HF_HUB_OFFLINE"
;
}
}
}
}
...
@@ -436,6 +440,7 @@ mod tests {
...
@@ -436,6 +440,7 @@ mod tests {
model
::
huggingface
::
HF_TOKEN
,
model
::
huggingface
::
HF_TOKEN
,
model
::
huggingface
::
HF_HUB_CACHE
,
model
::
huggingface
::
HF_HUB_CACHE
,
model
::
huggingface
::
HF_HOME
,
model
::
huggingface
::
HF_HOME
,
model
::
huggingface
::
HF_HUB_OFFLINE
,
// Event Plane
// Event Plane
event_plane
::
DYN_EVENT_PLANE
,
event_plane
::
DYN_EVENT_PLANE
,
event_plane
::
DYN_EVENT_PLANE_CODEC
,
event_plane
::
DYN_EVENT_PLANE_CODEC
,
...
...
lib/runtime/src/transports/nats.rs
View file @
584020f4
...
@@ -335,12 +335,6 @@ impl ClientOptions {
...
@@ -335,12 +335,6 @@ impl ClientOptions {
let
js_ctx
=
jetstream
::
new
(
client
.clone
());
let
js_ctx
=
jetstream
::
new
(
client
.clone
());
// Validate JetStream is available
js_ctx
.query_account
()
.await
.map_err
(|
e
|
anyhow
::
anyhow!
(
"JetStream not available: {e}"
))
?
;
Ok
(
Client
{
client
,
js_ctx
})
Ok
(
Client
{
client
,
js_ctx
})
}
}
}
}
...
...
tests/conftest.py
View file @
584020f4
...
@@ -298,7 +298,7 @@ class EtcdServer(ManagedProcess):
...
@@ -298,7 +298,7 @@ class EtcdServer(ManagedProcess):
class
NatsServer
(
ManagedProcess
):
class
NatsServer
(
ManagedProcess
):
def
__init__
(
self
,
request
,
port
=
4222
,
timeout
=
300
):
def
__init__
(
self
,
request
,
port
=
4222
,
timeout
=
300
,
disable_jetstream
=
False
):
# Allocate a free port if port is 0
# Allocate a free port if port is 0
use_random_port
=
port
==
0
use_random_port
=
port
==
0
if
use_random_port
:
if
use_random_port
:
...
@@ -309,16 +309,16 @@ class NatsServer(ManagedProcess):
...
@@ -309,16 +309,16 @@ class NatsServer(ManagedProcess):
self
.
use_random_port
=
use_random_port
# Track if we allocated the port
self
.
use_random_port
=
use_random_port
# Track if we allocated the port
self
.
_request
=
request
# Store for restart
self
.
_request
=
request
# Store for restart
self
.
_timeout
=
timeout
self
.
_timeout
=
timeout
data_dir
=
tempfile
.
mkdtemp
(
prefix
=
"nats_"
)
self
.
_disable_jetstream
=
disable_jetstream
data_dir
=
tempfile
.
mkdtemp
(
prefix
=
"nats_"
)
if
not
disable_jetstream
else
None
command
=
[
command
=
[
"nats-server"
,
"nats-server"
,
"-js"
,
"--trace"
,
"--trace"
,
"--store_dir"
,
data_dir
,
"-p"
,
"-p"
,
str
(
port
),
str
(
port
),
]
]
if
not
disable_jetstream
and
data_dir
:
command
.
extend
([
"-js"
,
"--store_dir"
,
data_dir
])
super
().
__init__
(
super
().
__init__
(
command
=
command
,
command
=
command
,
timeout
=
timeout
,
timeout
=
timeout
,
...
@@ -326,9 +326,45 @@ class NatsServer(ManagedProcess):
...
@@ -326,9 +326,45 @@ class NatsServer(ManagedProcess):
terminate_existing
=
not
use_random_port
,
# Disabled for parallel test execution with random ports
terminate_existing
=
not
use_random_port
,
# Disabled for parallel test execution with random ports
data_dir
=
data_dir
,
data_dir
=
data_dir
,
health_check_ports
=
[
port
],
health_check_ports
=
[
port
],
health_check_funcs
=
[
self
.
_nats_ready
],
log_dir
=
request
.
node
.
name
,
log_dir
=
request
.
node
.
name
,
)
)
def
_nats_ready
(
self
,
timeout
:
float
=
5
)
->
bool
:
"""Verify NATS server is ready by connecting and optionally checking JetStream."""
import
asyncio
import
nats
async
def
check
():
try
:
nc
=
await
nats
.
connect
(
f
"nats://localhost:
{
self
.
port
}
"
,
connect_timeout
=
min
(
timeout
,
2
),
)
try
:
if
not
self
.
_disable_jetstream
:
# Verify JetStream is initialized
js
=
nc
.
jetstream
()
await
js
.
account_info
()
return
True
finally
:
await
nc
.
close
()
except
Exception
:
return
False
# Handle both sync and async contexts
try
:
asyncio
.
get_running_loop
()
# Check if we're in async context
# Already in async context - run in a thread to avoid blocking
import
concurrent.futures
with
concurrent
.
futures
.
ThreadPoolExecutor
()
as
pool
:
return
pool
.
submit
(
asyncio
.
run
,
check
()).
result
(
timeout
=
timeout
)
except
RuntimeError
:
# No running loop - safe to use asyncio.run()
return
asyncio
.
run
(
check
())
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
"""Release allocated port when server exits."""
"""Release allocated port when server exits."""
try
:
try
:
...
@@ -344,9 +380,10 @@ class NatsServer(ManagedProcess):
...
@@ -344,9 +380,10 @@ class NatsServer(ManagedProcess):
"""Stop the NATS server for restart. Does not release port or clean up fully."""
"""Stop the NATS server for restart. Does not release port or clean up fully."""
_logger
.
info
(
f
"Stopping NATS server on port
{
self
.
port
}
"
)
_logger
.
info
(
f
"Stopping NATS server on port
{
self
.
port
}
"
)
self
.
_terminate_process_group
()
self
.
_terminate_process_group
()
if
self
.
proc
:
proc
=
self
.
proc
# type: ignore[has-type]
if
proc
is
not
None
:
try
:
try
:
self
.
proc
.
wait
(
timeout
=
10
)
proc
.
wait
(
timeout
=
10
)
except
Exception
as
e
:
except
Exception
as
e
:
_logger
.
warning
(
f
"Error waiting for NATS process to stop:
{
e
}
"
)
_logger
.
warning
(
f
"Error waiting for NATS process to stop:
{
e
}
"
)
self
.
proc
=
None
self
.
proc
=
None
...
@@ -354,130 +391,130 @@ class NatsServer(ManagedProcess):
...
@@ -354,130 +391,130 @@ class NatsServer(ManagedProcess):
def
start
(
self
):
def
start
(
self
):
"""Restart a stopped NATS server with fresh state."""
"""Restart a stopped NATS server with fresh state."""
_logger
.
info
(
f
"Starting NATS server on port
{
self
.
port
}
with fresh state"
)
_logger
.
info
(
f
"Starting NATS server on port
{
self
.
port
}
with fresh state"
)
# Clean up old data directory and create fresh one
# Clean up old data directory and create fresh one (only if JetStream enabled)
if
self
.
data_dir
:
if
not
self
.
_disable_jetstream
:
shutil
.
rmtree
(
self
.
data_dir
,
ignore_errors
=
True
)
old_data_dir
=
self
.
data_dir
# type: ignore[has-type]
self
.
data_dir
=
tempfile
.
mkdtemp
(
prefix
=
"nats_"
)
if
old_data_dir
is
not
None
:
shutil
.
rmtree
(
old_data_dir
,
ignore_errors
=
True
)
# Rebuild command with new data_dir
self
.
data_dir
=
tempfile
.
mkdtemp
(
prefix
=
"nats_"
)
# Rebuild command
self
.
command
=
[
self
.
command
=
[
"nats-server"
,
"nats-server"
,
"-js"
,
"--trace"
,
"--trace"
,
"--store_dir"
,
self
.
data_dir
,
"-p"
,
"-p"
,
str
(
self
.
port
),
str
(
self
.
port
),
]
]
if
not
self
.
_disable_jetstream
and
self
.
data_dir
:
self
.
command
.
extend
([
"-js"
,
"--store_dir"
,
self
.
data_dir
])
self
.
_start_process
()
self
.
_start_process
()
self
.
_check_ports
(
self
.
_timeout
)
elapsed
=
self
.
_check_ports
(
self
.
_timeout
)
self
.
_check_funcs
(
self
.
_timeout
-
elapsed
)
class
SharedManagedProcess
:
class
SharedManagedProcess
:
"""Base class for ManagedProcess with file-based reference counting for multi-process sharing."""
"""Base class for persistent shared processes across pytest-xdist workers.
Simplified design: first worker starts the process on a dynamic port, it lives forever
(until the container dies). No ref counting, no teardown. Subsequent workers just
reuse via port check. This eliminates race conditions and simplifies the logic.
"""
def
__init__
(
def
__init__
(
self
,
self
,
request
,
request
,
tmp_path_factory
,
tmp_path_factory
,
resource_name
:
str
,
resource_name
:
str
,
port
:
int
,
start_
port
:
int
,
timeout
:
int
=
300
,
timeout
:
int
=
300
,
):
):
self
.
request
=
request
self
.
request
=
request
self
.
port
=
port
self
.
start_port
=
start_port
self
.
port
:
Optional
[
int
]
=
None
# Set when entering context
self
.
timeout
=
timeout
self
.
timeout
=
timeout
self
.
resource_name
=
resource_name
self
.
resource_name
=
resource_name
self
.
_server
:
Optional
[
ManagedProcess
]
=
None
self
.
_server
:
Optional
[
ManagedProcess
]
=
None
self
.
_owns_process
=
False
root_tmp
=
Path
(
tempfile
.
gettempdir
())
/
"pytest_
ref_counting
"
root_tmp
=
Path
(
tempfile
.
gettempdir
())
/
"pytest_
shared_services
"
root_tmp
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
root_tmp
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
self
.
ref
_file
=
root_tmp
/
f
"
pytest_
{
resource_name
}
_
{
port
}
_ref_count
"
self
.
port
_file
=
root_tmp
/
f
"
{
resource_name
}
_port"
self
.
lock_file
=
str
(
self
.
ref
_file
)
+
".lock"
self
.
lock_file
=
str
(
self
.
port
_file
)
+
".lock"
def
_create_server
(
self
)
->
ManagedProcess
:
def
_create_server
(
self
,
port
:
int
)
->
ManagedProcess
:
"""Create the underlying server instance. Must be implemented by subclasses."""
"""Create the underlying server instance. Must be implemented by subclasses."""
raise
NotImplementedError
raise
NotImplementedError
def
_read_ref_count
(
self
)
->
int
:
def
_is_port_in_use
(
self
,
port
:
int
)
->
bool
:
"""Read current reference count."""
"""Check if a port is in use (i.e., a process is listening on it)."""
if
self
.
ref_file
.
exists
():
import
socket
try
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
settimeout
(
1
)
result
=
sock
.
connect_ex
((
"localhost"
,
port
))
sock
.
close
()
return
result
==
0
# 0 means connection succeeded (port in use)
except
Exception
:
return
False
def
_read_port
(
self
)
->
Optional
[
int
]:
"""Read stored port from file."""
if
self
.
port_file
.
exists
():
try
:
try
:
return
int
(
self
.
ref
_file
.
read_text
().
strip
())
return
int
(
self
.
port
_file
.
read_text
().
strip
())
except
(
ValueError
,
IOError
):
except
(
ValueError
,
IOError
):
return
0
return
None
return
0
return
None
def
_write_ref_count
(
self
,
count
:
int
):
def
_write_port
(
self
,
port
:
int
):
"""Write reference count atomically."""
"""Write port to file."""
self
.
ref_file
.
write_text
(
str
(
count
))
self
.
port_file
.
write_text
(
str
(
port
))
def
_increment_ref_count
(
self
)
->
int
:
"""Increment reference count and return new count."""
count
=
self
.
_read_ref_count
()
count
+=
1
self
.
_write_ref_count
(
count
)
return
count
def
_decrement_ref_count
(
self
)
->
int
:
"""Decrement reference count and return new count."""
count
=
self
.
_read_ref_count
()
count
=
max
(
0
,
count
-
1
)
self
.
_write_ref_count
(
count
)
return
count
def
__enter__
(
self
):
def
__enter__
(
self
):
with
FileLock
(
self
.
lock_file
):
with
FileLock
(
self
.
lock_file
):
ref_count
=
self
.
_increment_ref_count
()
stored_port
=
self
.
_read_port
()
if
ref_count
==
1
:
# First reference - start the process
# Check if a process is already running on the stored port
self
.
_server
=
self
.
_create_server
()
if
stored_port
is
not
None
and
self
.
_is_port_in_use
(
stored_port
):
self
.
_server
.
__enter__
()
# Reuse existing process
self
.
_owns_process
=
True
self
.
port
=
stored_port
logging
.
info
(
f
"[
{
self
.
resource_name
}
] Started process (ref_count=1)"
)
logging
.
info
(
f
"[
{
self
.
resource_name
}
] Reusing existing process on port
{
self
.
port
}
"
)
else
:
else
:
# Process already running, just track reference
# Start new process
self
.
_owns_process
=
False
if
stored_port
is
not
None
:
logging
.
warning
(
f
"[
{
self
.
resource_name
}
] Stale port file: port
{
stored_port
}
not in use, starting fresh"
)
self
.
port
=
allocate_port
(
self
.
start_port
)
self
.
_write_port
(
self
.
port
)
self
.
_server
=
self
.
_create_server
(
self
.
port
)
self
.
_server
.
__enter__
()
logging
.
info
(
logging
.
info
(
f
"[
{
self
.
resource_name
}
]
Reusing existing process (ref_count=
{
ref_coun
t
}
)
"
f
"[
{
self
.
resource_name
}
]
Started process on port
{
self
.
por
t
}
"
)
)
return
self
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
with
FileLock
(
self
.
lock_file
):
# Never tear down - let the process live until the container dies.
ref_count
=
self
.
_decrement_ref_count
()
# This avoids race conditions and simplifies the logic.
if
ref_count
==
0
and
self
.
_owns_process
:
pass
# Last reference - stop the process
if
self
.
_server
:
self
.
_server
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
logging
.
info
(
f
"[
{
self
.
resource_name
}
] Stopped process (ref_count=0)"
)
elif
ref_count
==
0
:
# Last reference but we don't own it - shouldn't happen, but clean up ref file
if
self
.
ref_file
.
exists
():
self
.
ref_file
.
unlink
()
logging
.
warning
(
f
"[
{
self
.
resource_name
}
] Ref count reached 0 but we don't own process"
)
else
:
logging
.
info
(
f
"[
{
self
.
resource_name
}
] Released reference (ref_count=
{
ref_count
}
)"
)
class
SharedEtcdServer
(
SharedManagedProcess
):
class
SharedEtcdServer
(
SharedManagedProcess
):
"""EtcdServer with file-based reference counting for multi-process sharing."""
"""EtcdServer with file-based reference counting for multi-process sharing."""
def
__init__
(
self
,
request
,
tmp_path_factory
,
port
=
23
79
,
timeout
=
300
):
def
__init__
(
self
,
request
,
tmp_path_factory
,
start_
port
=
23
80
,
timeout
=
300
):
super
().
__init__
(
request
,
tmp_path_factory
,
"etcd"
,
port
,
timeout
)
super
().
__init__
(
request
,
tmp_path_factory
,
"etcd"
,
start_
port
,
timeout
)
# Create a log directory for session-scoped servers
# Create a log directory for session-scoped servers
self
.
_log_dir
=
tempfile
.
mkdtemp
(
prefix
=
f
"pytest_
{
self
.
resource_name
}
_logs_"
)
self
.
_log_dir
=
tempfile
.
mkdtemp
(
prefix
=
f
"pytest_
{
self
.
resource_name
}
_logs_"
)
def
_create_server
(
self
)
->
ManagedProcess
:
def
_create_server
(
self
,
port
:
int
)
->
ManagedProcess
:
"""Create EtcdServer instance."""
"""Create EtcdServer instance."""
server
=
EtcdServer
(
self
.
request
,
port
=
self
.
port
,
timeout
=
self
.
timeout
)
server
=
EtcdServer
(
self
.
request
,
port
=
port
,
timeout
=
self
.
timeout
)
# Override log_dir since request.node.name is empty in session scope
# Override log_dir since request.node.name is empty in session scope
server
.
log_dir
=
self
.
_log_dir
server
.
log_dir
=
self
.
_log_dir
return
server
return
server
...
@@ -486,14 +523,27 @@ class SharedEtcdServer(SharedManagedProcess):
...
@@ -486,14 +523,27 @@ class SharedEtcdServer(SharedManagedProcess):
class
SharedNatsServer
(
SharedManagedProcess
):
class
SharedNatsServer
(
SharedManagedProcess
):
"""NatsServer with file-based reference counting for multi-process sharing."""
"""NatsServer with file-based reference counting for multi-process sharing."""
def
__init__
(
self
,
request
,
tmp_path_factory
,
port
=
4222
,
timeout
=
300
):
def
__init__
(
super
().
__init__
(
request
,
tmp_path_factory
,
"nats"
,
port
,
timeout
)
self
,
request
,
tmp_path_factory
,
start_port
=
4223
,
timeout
=
300
,
disable_jetstream
=
False
,
):
super
().
__init__
(
request
,
tmp_path_factory
,
"nats"
,
start_port
,
timeout
)
# Create a log directory for session-scoped servers
# Create a log directory for session-scoped servers
self
.
_log_dir
=
tempfile
.
mkdtemp
(
prefix
=
f
"pytest_
{
self
.
resource_name
}
_logs_"
)
self
.
_log_dir
=
tempfile
.
mkdtemp
(
prefix
=
f
"pytest_
{
self
.
resource_name
}
_logs_"
)
self
.
_disable_jetstream
=
disable_jetstream
def
_create_server
(
self
)
->
ManagedProcess
:
def
_create_server
(
self
,
port
:
int
)
->
ManagedProcess
:
"""Create NatsServer instance."""
"""Create NatsServer instance."""
server
=
NatsServer
(
self
.
request
,
port
=
self
.
port
,
timeout
=
self
.
timeout
)
server
=
NatsServer
(
self
.
request
,
port
=
port
,
timeout
=
self
.
timeout
,
disable_jetstream
=
self
.
_disable_jetstream
,
)
# Override log_dir since request.node.name is empty in session scope
# Override log_dir since request.node.name is empty in session scope
server
.
log_dir
=
self
.
_log_dir
server
.
log_dir
=
self
.
_log_dir
return
server
return
server
...
@@ -525,6 +575,27 @@ def request_plane(request):
...
@@ -525,6 +575,27 @@ def request_plane(request):
return
getattr
(
request
,
"param"
,
"nats"
)
return
getattr
(
request
,
"param"
,
"nats"
)
@
pytest
.
fixture
def
use_nats_core
(
request
):
"""
Whether to use NATS Core mode (local indexer) instead of JetStream. Defaults to False.
When True:
- NATS server starts without JetStream (-js flag omitted) for faster startup
- Tests should use enable_local_indexer=True in mocker_args
When False (default):
- NATS server starts with JetStream for KV event distribution
- Tests use JetStream-based indexer synchronization
To use NATS Core mode:
@pytest.mark.parametrize("use_nats_core", [True], indirect=True)
def test_example(runtime_services_dynamic_ports):
...
"""
return
getattr
(
request
,
"param"
,
False
)
@
pytest
.
fixture
()
@
pytest
.
fixture
()
def
runtime_services
(
request
,
store_kv
,
request_plane
):
def
runtime_services
(
request
,
store_kv
,
request_plane
):
"""
"""
...
@@ -551,7 +622,7 @@ def runtime_services(request, store_kv, request_plane):
...
@@ -551,7 +622,7 @@ def runtime_services(request, store_kv, request_plane):
@
pytest
.
fixture
()
@
pytest
.
fixture
()
def
runtime_services_dynamic_ports
(
request
,
store_kv
,
request_plane
):
def
runtime_services_dynamic_ports
(
request
,
store_kv
,
request_plane
,
use_nats_core
):
"""Provide NATS and Etcd servers with truly dynamic ports per test.
"""Provide NATS and Etcd servers with truly dynamic ports per test.
This fixture actually allocates dynamic ports by passing port=0 to the servers.
This fixture actually allocates dynamic ports by passing port=0 to the servers.
...
@@ -566,6 +637,7 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane):
...
@@ -566,6 +637,7 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane):
- If store_kv != "etcd", etcd is not started (returns None)
- If store_kv != "etcd", etcd is not started (returns None)
- NATS is always started when etcd is used, because KV events require NATS
- NATS is always started when etcd is used, because KV events require NATS
regardless of the request_plane (tcp/nats only affects request transport)
regardless of the request_plane (tcp/nats only affects request transport)
- JetStream is enabled by default; disabled when use_nats_core=True for faster startup
Returns a tuple of (nats_process, etcd_process) where each has a .port attribute.
Returns a tuple of (nats_process, etcd_process) where each has a .port attribute.
"""
"""
...
@@ -573,24 +645,42 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane):
...
@@ -573,24 +645,42 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane):
# Port cleanup is now handled in NatsServer and EtcdServer __exit__ methods
# Port cleanup is now handled in NatsServer and EtcdServer __exit__ methods
# Always start NATS when etcd is used - KV events require NATS regardless of request_plane
# Always start NATS when etcd is used - KV events require NATS regardless of request_plane
# When use_nats_core=True, disable JetStream for faster startup
if
store_kv
==
"etcd"
:
if
store_kv
==
"etcd"
:
with
NatsServer
(
request
,
port
=
0
)
as
nats_process
:
with
NatsServer
(
request
,
port
=
0
,
disable_jetstream
=
use_nats_core
)
as
nats_process
:
with
EtcdServer
(
request
,
port
=
0
)
as
etcd_process
:
with
EtcdServer
(
request
,
port
=
0
)
as
etcd_process
:
# Set environment variables for Rust/Python runtime to use. Note that xdist (parallel execution)
# Save original env vars (may be set by session-scoped fixture)
# will launch isolated tests in a new process, so no need to worry about environment pollution.
orig_nats
=
os
.
environ
.
get
(
"NATS_SERVER"
)
orig_etcd
=
os
.
environ
.
get
(
"ETCD_ENDPOINTS"
)
# Set environment variables for this test's dynamic ports
os
.
environ
[
"NATS_SERVER"
]
=
f
"nats://localhost:
{
nats_process
.
port
}
"
os
.
environ
[
"NATS_SERVER"
]
=
f
"nats://localhost:
{
nats_process
.
port
}
"
os
.
environ
[
"ETCD_ENDPOINTS"
]
=
f
"http://localhost:
{
etcd_process
.
port
}
"
os
.
environ
[
"ETCD_ENDPOINTS"
]
=
f
"http://localhost:
{
etcd_process
.
port
}
"
yield
nats_process
,
etcd_process
yield
nats_process
,
etcd_process
# No test should rely on these variables after the test, but clean up just in case.
# Restore original env vars (or remove if they weren't set)
os
.
environ
.
pop
(
"NATS_SERVER"
,
None
)
if
orig_nats
is
not
None
:
os
.
environ
.
pop
(
"ETCD_ENDPOINTS"
,
None
)
os
.
environ
[
"NATS_SERVER"
]
=
orig_nats
else
:
os
.
environ
.
pop
(
"NATS_SERVER"
,
None
)
if
orig_etcd
is
not
None
:
os
.
environ
[
"ETCD_ENDPOINTS"
]
=
orig_etcd
else
:
os
.
environ
.
pop
(
"ETCD_ENDPOINTS"
,
None
)
elif
request_plane
==
"nats"
:
elif
request_plane
==
"nats"
:
with
NatsServer
(
request
,
port
=
0
)
as
nats_process
:
with
NatsServer
(
request
,
port
=
0
,
disable_jetstream
=
use_nats_core
)
as
nats_process
:
orig_nats
=
os
.
environ
.
get
(
"NATS_SERVER"
)
os
.
environ
[
"NATS_SERVER"
]
=
f
"nats://localhost:
{
nats_process
.
port
}
"
os
.
environ
[
"NATS_SERVER"
]
=
f
"nats://localhost:
{
nats_process
.
port
}
"
yield
nats_process
,
None
yield
nats_process
,
None
os
.
environ
.
pop
(
"NATS_SERVER"
,
None
)
if
orig_nats
is
not
None
:
os
.
environ
[
"NATS_SERVER"
]
=
orig_nats
else
:
os
.
environ
.
pop
(
"NATS_SERVER"
,
None
)
else
:
else
:
yield
None
,
None
yield
None
,
None
...
@@ -599,22 +689,28 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane):
...
@@ -599,22 +689,28 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane):
def
runtime_services_session
(
request
,
tmp_path_factory
):
def
runtime_services_session
(
request
,
tmp_path_factory
):
"""Session-scoped fixture that provides shared NATS and etcd instances for all tests.
"""Session-scoped fixture that provides shared NATS and etcd instances for all tests.
Uses file-based reference counting to coordinate between pytest-xdist worker processes.
Uses file locking to coordinate between pytest-xdist worker processes.
Only the first worker starts services, and only the last worker tears them down.
First worker starts services on dynamic ports, subsequent workers reuse them.
Services are never torn down (live until container dies) to avoid race conditions.
WARNING: may not be parallel/xdist safe.
This fixture is xdist-safe when tests use unique namespaces (e.g. random suffixes)
- This fixture shares one NATS + one etcd across many tests (and across xdist workers).
and do not assume exclusive access to global streams/keys.
- It is only safe if tests fully isolate state (e.g. unique namespaces) and do not
assume exclusive access to global streams/keys/ports.
- Prefer `runtime_services_dynamic_ports` for true per-test isolation in parallel runs.
TODO: once nothing
use
s
`runtime_services_
session`, make the per-test
dynamic
ports
For tests that need to restart NATS (e.g. indexer sync),
use `runtime_services_dynamic
_
ports
`
behavior the default for router/fron
te
n
d in
tegration test
s.
which provides per-test isola
ted in
stance
s.
"""
"""
with
SharedNatsServer
(
request
,
tmp_path_factory
)
as
nats
:
with
SharedNatsServer
(
request
,
tmp_path_factory
)
as
nats
:
with
SharedEtcdServer
(
request
,
tmp_path_factory
)
as
etcd
:
with
SharedEtcdServer
(
request
,
tmp_path_factory
)
as
etcd
:
# Set environment variables for Rust/Python runtime to use
os
.
environ
[
"NATS_SERVER"
]
=
f
"nats://localhost:
{
nats
.
port
}
"
os
.
environ
[
"ETCD_ENDPOINTS"
]
=
f
"http://localhost:
{
etcd
.
port
}
"
yield
nats
,
etcd
yield
nats
,
etcd
# Clean up environment variables
os
.
environ
.
pop
(
"NATS_SERVER"
,
None
)
os
.
environ
.
pop
(
"ETCD_ENDPOINTS"
,
None
)
@
pytest
.
fixture
@
pytest
.
fixture
def
file_storage_backend
():
def
file_storage_backend
():
...
...
tests/router/common.py
View file @
584020f4
...
@@ -525,10 +525,10 @@ async def send_request_via_python_kv_router(
...
@@ -525,10 +525,10 @@ async def send_request_via_python_kv_router(
stream
=
await
kv_python_router
.
generate
(
stream
=
await
kv_python_router
.
generate
(
token_ids
=
token_ids
,
token_ids
=
token_ids
,
model
=
model_name
,
model
=
model_name
,
stop_conditions
=
stop_conditions
,
stop_conditions
=
stop_conditions
,
# type: ignore[arg-type]
sampling_options
=
sampling_options
,
sampling_options
=
sampling_options
,
# type: ignore[arg-type]
output_options
=
output_options
,
output_options
=
output_options
,
# type: ignore[arg-type]
router_config_override
=
router_config_override
,
router_config_override
=
router_config_override
,
# type: ignore[arg-type]
worker_id
=
worker_id
,
worker_id
=
worker_id
,
dp_rank
=
dp_rank
,
dp_rank
=
dp_rank
,
)
)
...
@@ -693,6 +693,7 @@ def _test_router_two_routers(
...
@@ -693,6 +693,7 @@ def _test_router_two_routers(
test_payload
:
dict
,
test_payload
:
dict
,
num_requests
:
int
,
num_requests
:
int
,
store_backend
:
str
=
"etcd"
,
store_backend
:
str
=
"etcd"
,
skip_consumer_verification
:
bool
=
False
,
):
):
"""Test two KV routers with alternating requests and consumer lifecycle verification.
"""Test two KV routers with alternating requests and consumer lifecycle verification.
...
@@ -701,8 +702,8 @@ def _test_router_two_routers(
...
@@ -701,8 +702,8 @@ def _test_router_two_routers(
This test:
This test:
1. Starts two KV routers on different ports
1. Starts two KV routers on different ports
2. Sends requests alternating between the two routers
2. Sends requests alternating between the two routers
3. Verifies that both routers create durable consumers
3. Verifies that both routers create durable consumers
(unless skipped)
4. Verifies consumers are cleaned up when routers exit
4. Verifies consumers are cleaned up when routers exit
(unless skipped)
Args:
Args:
engine_workers: Backend workers (mocker/vllm) already initialized with __enter__()
engine_workers: Backend workers (mocker/vllm) already initialized with __enter__()
...
@@ -712,6 +713,7 @@ def _test_router_two_routers(
...
@@ -712,6 +713,7 @@ def _test_router_two_routers(
test_payload: Test payload to send to /v1/chat/completions
test_payload: Test payload to send to /v1/chat/completions
num_requests: Number of concurrent requests to send
num_requests: Number of concurrent requests to send
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
skip_consumer_verification: Skip JetStream consumer verification (for NATS Core mode).
Raises:
Raises:
AssertionError: If consumer lifecycle verification fails
AssertionError: If consumer lifecycle verification fails
...
@@ -846,8 +848,14 @@ def _test_router_two_routers(
...
@@ -846,8 +848,14 @@ def _test_router_two_routers(
finally
:
finally
:
await
nc
.
close
()
await
nc
.
close
()
# Run consumer lifecycle verification
# Run consumer lifecycle verification (skip for NATS Core mode)
asyncio
.
run
(
verify_consumer_lifecycle
())
if
skip_consumer_verification
:
logger
.
info
(
"Skipping JetStream consumer verification (NATS Core mode)"
)
# Clean up routers manually since we're not doing consumer verification
for
kv_router
in
kv_routers
:
kv_router
.
__exit__
(
None
,
None
,
None
)
else
:
asyncio
.
run
(
verify_consumer_lifecycle
())
# Clear the kv_routers list since we've already cleaned them up
# Clear the kv_routers list since we've already cleaned them up
kv_routers
=
[]
kv_routers
=
[]
...
...
tests/router/test_router_e2e_with_mockers.py
View file @
584020f4
...
@@ -323,8 +323,15 @@ class DisaggMockerProcess:
...
@@ -323,8 +323,15 @@ class DisaggMockerProcess:
@
pytest
.
mark
.
timeout
(
42
)
# ~3x average (~13.80s), rounded up
@
pytest
.
mark
.
timeout
(
42
)
# ~3x average (~13.80s), rounded up
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"use_nats_core"
,
[
True
],
indirect
=
True
)
# Use NATS Core (local indexer)
def
test_mocker_kv_router
(
def
test_mocker_kv_router
(
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
request_plane
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
request_plane
,
use_nats_core
,
):
):
"""
"""
Test KV router with multiple mocker engine instances.
Test KV router with multiple mocker engine instances.
...
@@ -335,8 +342,12 @@ def test_mocker_kv_router(
...
@@ -335,8 +342,12 @@ def test_mocker_kv_router(
# runtime_services starts etcd and optionally nats based on request_plane
# runtime_services starts etcd and optionally nats based on request_plane
logger
.
info
(
f
"Starting mocker KV router test with request_plane=
{
request_plane
}
"
)
logger
.
info
(
f
"Starting mocker KV router test with request_plane=
{
request_plane
}
"
)
# Create mocker args dictionary
# Create mocker args dictionary - use local indexer (NATS Core mode)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"enable_local_indexer"
:
use_nats_core
,
}
try
:
try
:
# Start mocker instances with the new CLI interface
# Start mocker instances with the new CLI interface
...
@@ -372,6 +383,9 @@ def test_mocker_kv_router(
...
@@ -372,6 +383,9 @@ def test_mocker_kv_router(
@
pytest
.
mark
.
parametrize
(
"store_backend"
,
[
"etcd"
,
"file"
])
@
pytest
.
mark
.
parametrize
(
"store_backend"
,
[
"etcd"
,
"file"
])
@
pytest
.
mark
.
parametrize
(
"use_nats_core"
,
[
True
],
indirect
=
True
)
# Use NATS Core (local indexer)
@
pytest
.
mark
.
timeout
(
60
)
# ~3x average (~19.86s), rounded up
@
pytest
.
mark
.
timeout
(
60
)
# ~3x average (~19.86s), rounded up
def
test_mocker_two_kv_router
(
def
test_mocker_two_kv_router
(
request
,
request
,
...
@@ -379,6 +393,7 @@ def test_mocker_two_kv_router(
...
@@ -379,6 +393,7 @@ def test_mocker_two_kv_router(
predownload_tokenizers
,
predownload_tokenizers
,
file_storage_backend
,
file_storage_backend
,
store_backend
,
store_backend
,
use_nats_core
,
):
):
"""
"""
Test with two KV routers and multiple mocker engine instances.
Test with two KV routers and multiple mocker engine instances.
...
@@ -391,8 +406,12 @@ def test_mocker_two_kv_router(
...
@@ -391,8 +406,12 @@ def test_mocker_two_kv_router(
f
"Starting mocker two KV router test with
{
store_backend
}
storage backend"
f
"Starting mocker two KV router test with
{
store_backend
}
storage backend"
)
)
# Create mocker args dictionary
# Create mocker args dictionary - use local indexer (NATS Core mode)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"enable_local_indexer"
:
use_nats_core
,
}
try
:
try
:
# Start mocker instances with the new CLI interface
# Start mocker instances with the new CLI interface
...
@@ -420,6 +439,7 @@ def test_mocker_two_kv_router(
...
@@ -420,6 +439,7 @@ def test_mocker_two_kv_router(
test_payload
=
TEST_PAYLOAD
,
test_payload
=
TEST_PAYLOAD
,
num_requests
=
NUM_REQUESTS
,
num_requests
=
NUM_REQUESTS
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
skip_consumer_verification
=
use_nats_core
,
# Skip JetStream checks in NATS Core mode
)
)
finally
:
finally
:
...
@@ -428,17 +448,21 @@ def test_mocker_two_kv_router(
...
@@ -428,17 +448,21 @@ def test_mocker_two_kv_router(
@
pytest
.
mark
.
skip
(
reason
=
"Flaky, temporarily disabled"
)
@
pytest
.
mark
.
skip
(
reason
=
"Flaky, temporarily disabled"
)
@
pytest
.
mark
.
parametrize
(
"use_nats_core"
,
[
True
],
indirect
=
True
)
# Use NATS Core (local indexer)
@
pytest
.
mark
.
timeout
(
60
)
# ~3x average (~19.86s), rounded up (when enabled)
@
pytest
.
mark
.
timeout
(
60
)
# ~3x average (~19.86s), rounded up (when enabled)
def
test_mocker_kv_router_overload_503
(
def
test_mocker_kv_router_overload_503
(
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
use_nats_core
):
):
"""Test that KV router returns 503 when mocker workers are overloaded."""
"""Test that KV router returns 503 when mocker workers are overloaded."""
logger
.
info
(
"Starting mocker KV router overload test for 503 status"
)
logger
.
info
(
"Starting mocker KV router overload test for 503 status"
)
# Create mocker args dictionary with limited resources
# Create mocker args dictionary with limited resources
- use local indexer (NATS Core mode)
mocker_args
=
{
mocker_args
=
{
"speedup_ratio"
:
10
,
"speedup_ratio"
:
10
,
"block_size"
:
4
,
# Smaller block size
"block_size"
:
4
,
# Smaller block size
"num_gpu_blocks"
:
64
,
# Limited GPU blocks to exhaust quickly
"num_gpu_blocks"
:
64
,
# Limited GPU blocks to exhaust quickly
"enable_local_indexer"
:
use_nats_core
,
}
}
try
:
try
:
...
@@ -468,12 +492,24 @@ def test_mocker_kv_router_overload_503(
...
@@ -468,12 +492,24 @@ def test_mocker_kv_router_overload_503(
@
pytest
.
mark
.
timeout
(
22
)
# ~3x average (~7.10s), rounded up
@
pytest
.
mark
.
timeout
(
22
)
# ~3x average (~7.10s), rounded up
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"use_nats_core"
,
[
True
],
indirect
=
True
)
# Use NATS Core (local indexer)
def
test_kv_push_router_bindings
(
def
test_kv_push_router_bindings
(
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
request_plane
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
request_plane
,
use_nats_core
,
):
):
"""Test KvPushRouter Python bindings with mocker engines."""
"""Test KvPushRouter Python bindings with mocker engines."""
logger
.
info
(
"Starting KvPushRouter bindings test"
)
logger
.
info
(
"Starting KvPushRouter bindings test"
)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
# Use local indexer (NATS Core mode)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"enable_local_indexer"
:
use_nats_core
,
}
try
:
try
:
# Start mocker instances
# Start mocker instances
...
@@ -507,19 +543,19 @@ def test_kv_push_router_bindings(
...
@@ -507,19 +543,19 @@ def test_kv_push_router_bindings(
mockers
.
__exit__
(
None
,
None
,
None
)
mockers
.
__exit__
(
None
,
None
,
None
)
# NO @pytest.mark.parallel - nats_core variant stops/restarts NATS
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"store_backend,use_nats_core,request_plane"
,
"store_backend,use_nats_core,request_plane"
,
[
[
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
- uses JetStream (default)
(
"etcd"
,
True
,
"tcp"
),
# NATS core mode (with gap detection)
(
"etcd"
,
True
,
"tcp"
),
# NATS core mode (with gap detection)
- no JetStream
(
"file"
,
False
,
"nats"
),
# File backend
(
"file"
,
False
,
"nats"
),
# File backend
- uses JetStream (default)
],
],
ids
=
[
ids
=
[
"jetstream"
,
"jetstream"
,
"nats_core"
,
"nats_core"
,
"file"
,
"file"
,
],
],
indirect
=
[
"request_plane"
,
"use_nats_core"
],
)
)
@
pytest
.
mark
.
timeout
(
90
)
# TODO: figure out a timeout
@
pytest
.
mark
.
timeout
(
90
)
# TODO: figure out a timeout
def
test_indexers_sync
(
def
test_indexers_sync
(
...
@@ -590,12 +626,20 @@ def test_indexers_sync(
...
@@ -590,12 +626,20 @@ def test_indexers_sync(
@
pytest
.
mark
.
timeout
(
42
)
# ~3x average (~13.80s), rounded up
@
pytest
.
mark
.
timeout
(
42
)
# ~3x average (~13.80s), rounded up
@
pytest
.
mark
.
parametrize
(
"use_nats_core"
,
[
True
],
indirect
=
True
)
# Use NATS Core (local indexer)
def
test_query_instance_id_returns_worker_and_tokens
(
def
test_query_instance_id_returns_worker_and_tokens
(
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
use_nats_core
):
):
"""Test query_instance_id annotation with mocker engines."""
"""Test query_instance_id annotation with mocker engines."""
logger
.
info
(
"Starting KV router query_instance_id annotation test"
)
logger
.
info
(
"Starting KV router query_instance_id annotation test"
)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
# Use local indexer (NATS Core mode)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"enable_local_indexer"
:
use_nats_core
,
}
os
.
makedirs
(
request
.
node
.
name
,
exist_ok
=
True
)
os
.
makedirs
(
request
.
node
.
name
,
exist_ok
=
True
)
try
:
try
:
...
@@ -629,11 +673,12 @@ def test_query_instance_id_returns_worker_and_tokens(
...
@@ -629,11 +673,12 @@ def test_query_instance_id_returns_worker_and_tokens(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"use_nats_core,use_kv_events"
,
"use_nats_core,use_kv_events"
,
[
[
(
False
,
True
),
# JetStream mode (default)
(
False
,
True
),
# JetStream mode (default)
- uses JetStream
(
True
,
True
),
# NATS Core + local indexer mode
(
True
,
True
),
# NATS Core + local indexer mode
- no JetStream
(
False
,
False
),
# Approximate mode (--no-kv-events)
(
False
,
False
),
# Approximate mode (--no-kv-events)
- uses JetStream
],
],
ids
=
[
"jetstream"
,
"nats_core"
,
"no_kv_events"
],
ids
=
[
"jetstream"
,
"nats_core"
,
"no_kv_events"
],
indirect
=
[
"use_nats_core"
],
)
)
def
test_router_decisions
(
def
test_router_decisions
(
request
,
request
,
...
@@ -828,9 +873,16 @@ def test_router_decisions_disagg(
...
@@ -828,9 +873,16 @@ def test_router_decisions_disagg(
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"use_nats_core"
,
[
True
],
indirect
=
True
)
# Use NATS Core (local indexer)
@
pytest
.
mark
.
timeout
(
39
)
# ~3x average (~12.84s), rounded up
@
pytest
.
mark
.
timeout
(
39
)
# ~3x average (~12.84s), rounded up
def
test_busy_threshold_endpoint
(
def
test_busy_threshold_endpoint
(
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
request_plane
request
,
runtime_services_dynamic_ports
,
predownload_tokenizers
,
request_plane
,
use_nats_core
,
):
):
"""Test that the /busy_threshold endpoint can be hit and responds correctly.
"""Test that the /busy_threshold endpoint can be hit and responds correctly.
...
@@ -846,7 +898,12 @@ def test_busy_threshold_endpoint(
...
@@ -846,7 +898,12 @@ def test_busy_threshold_endpoint(
f
"Starting busy_threshold endpoint test with request_plane=
{
request_plane
}
"
f
"Starting busy_threshold endpoint test with request_plane=
{
request_plane
}
"
)
)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
}
# Use local indexer (NATS Core mode)
mocker_args
=
{
"speedup_ratio"
:
SPEEDUP_RATIO
,
"block_size"
:
BLOCK_SIZE
,
"enable_local_indexer"
:
use_nats_core
,
}
try
:
try
:
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
logger
.
info
(
f
"Starting
{
NUM_MOCKERS
}
mocker instances"
)
...
...
tests/router/test_router_e2e_with_sglang.py
View file @
584020f4
...
@@ -87,6 +87,7 @@ class SGLangProcess:
...
@@ -87,6 +87,7 @@ class SGLangProcess:
data_parallel_size
:
Optional
[
int
]
=
None
,
data_parallel_size
:
Optional
[
int
]
=
None
,
request_plane
:
str
=
"tcp"
,
request_plane
:
str
=
"tcp"
,
store_backend
:
str
=
"etcd"
,
store_backend
:
str
=
"etcd"
,
enable_local_indexer
:
bool
=
False
,
):
):
"""Initialize SGLang workers with dynamo integration.
"""Initialize SGLang workers with dynamo integration.
...
@@ -103,6 +104,7 @@ class SGLangProcess:
...
@@ -103,6 +104,7 @@ class SGLangProcess:
data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size)
data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size)
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
enable_local_indexer: If True, enable worker-local KV indexer for NATS Core mode. Defaults to False.
"""
"""
# Generate unique namespace for isolation
# Generate unique namespace for isolation
namespace_suffix
=
generate_random_suffix
()
namespace_suffix
=
generate_random_suffix
()
...
@@ -192,6 +194,10 @@ class SGLangProcess:
...
@@ -192,6 +194,10 @@ class SGLangProcess:
if
self
.
store_backend
==
"file"
and
"DYN_FILE_KV"
in
os
.
environ
:
if
self
.
store_backend
==
"file"
and
"DYN_FILE_KV"
in
os
.
environ
:
env_vars
[
"DYN_FILE_KV"
]
=
os
.
environ
[
"DYN_FILE_KV"
]
env_vars
[
"DYN_FILE_KV"
]
=
os
.
environ
[
"DYN_FILE_KV"
]
# Enable local indexer for NATS Core mode
if
enable_local_indexer
:
env_vars
[
"DYN_LOCAL_INDEXER"
]
=
"true"
env
.
update
(
env_vars
)
env
.
update
(
env_vars
)
# Create managed process for the worker
# Create managed process for the worker
...
@@ -313,7 +319,7 @@ class SGLangProcess:
...
@@ -313,7 +319,7 @@ class SGLangProcess:
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~46s/test), rounded up
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~46s/test), rounded up
def
test_sglang_kv_router_basic
(
def
test_sglang_kv_router_basic
(
request
,
request
,
...
@@ -369,7 +375,7 @@ def test_sglang_kv_router_basic(
...
@@ -369,7 +375,7 @@ def test_sglang_kv_router_basic(
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
skip
(
reason
=
"Broken by sglang changes"
)
@
pytest
.
mark
.
skip
(
reason
=
"Broken by sglang changes"
)
# TODO: Re-enable this test once https://github.com/sgl-project/sglang/pull/14934 is merged
# TODO: Re-enable this test once https://github.com/sgl-project/sglang/pull/14934 is merged
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
def
test_router_decisions_sglang_multiple_workers
(
def
test_router_decisions_sglang_multiple_workers
(
request
,
request
,
runtime_services_dynamic_ports
,
runtime_services_dynamic_ports
,
...
@@ -414,7 +420,7 @@ def test_router_decisions_sglang_multiple_workers(
...
@@ -414,7 +420,7 @@ def test_router_decisions_sglang_multiple_workers(
@
pytest
.
mark
.
gpu_2
@
pytest
.
mark
.
gpu_2
@
pytest
.
mark
.
post_merge
@
pytest
.
mark
.
post_merge
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
timeout
(
600
)
# 10 min max (multi-GPU + DP startup variance)
@
pytest
.
mark
.
timeout
(
600
)
# 10 min max (multi-GPU + DP startup variance)
def
test_router_decisions_sglang_dp
(
def
test_router_decisions_sglang_dp
(
request
,
request
,
...
@@ -469,10 +475,10 @@ def test_router_decisions_sglang_dp(
...
@@ -469,10 +475,10 @@ def test_router_decisions_sglang_dp(
"store_backend,use_nats_core,request_plane"
,
"store_backend,use_nats_core,request_plane"
,
[
[
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
# ("etcd", True, "tcp"), #
ignored, needs unconditional nats_client
# ("etcd", True, "tcp"), #
nats_core mode - disabled for now
# ("file", False, "nats"), # File backend - TODO: investigate file backend support for SGLang
# ("file", False, "nats"), # File backend - TODO: investigate file backend support for SGLang
],
],
ids
=
[
"jetstream"
],
# "nats_core" and "file" commented out
ids
=
[
"jetstream"
],
)
)
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~46s/test), rounded up
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~46s/test), rounded up
def
test_sglang_indexers_sync
(
def
test_sglang_indexers_sync
(
...
@@ -491,8 +497,11 @@ def test_sglang_indexers_sync(
...
@@ -491,8 +497,11 @@ def test_sglang_indexers_sync(
Tests with configuration:
Tests with configuration:
- jetstream: etcd backend, JetStream for KV events, NATS request plane
- jetstream: etcd backend, JetStream for KV events, NATS request plane
- tcp_nats_core: etcd backend, local indexer with NATS Core, TCP request plane
"""
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
# runtime_services_dynamic_ports handles NATS and etcd startup
nats_process
,
_etcd_process
=
runtime_services_dynamic_ports
logger
.
info
(
logger
.
info
(
f
"Starting SGLang indexers sync test: store_backend=
{
store_backend
}
, "
f
"Starting SGLang indexers sync test: store_backend=
{
store_backend
}
, "
f
"use_nats_core=
{
use_nats_core
}
, request_plane=
{
request_plane
}
"
f
"use_nats_core=
{
use_nats_core
}
, request_plane=
{
request_plane
}
"
...
@@ -510,6 +519,7 @@ def test_sglang_indexers_sync(
...
@@ -510,6 +519,7 @@ def test_sglang_indexers_sync(
single_gpu
=
True
,
# fit workers into one GPU
single_gpu
=
True
,
# fit workers into one GPU
request_plane
=
request_plane
,
request_plane
=
request_plane
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
enable_local_indexer
=
use_nats_core
,
)
)
logger
.
info
(
f
"All SGLang workers using namespace:
{
sglang_workers
.
namespace
}
"
)
logger
.
info
(
f
"All SGLang workers using namespace:
{
sglang_workers
.
namespace
}
"
)
sglang_workers
.
__enter__
()
sglang_workers
.
__enter__
()
...
@@ -523,6 +533,8 @@ def test_sglang_indexers_sync(
...
@@ -523,6 +533,8 @@ def test_sglang_indexers_sync(
num_workers
=
N_SGLANG_WORKERS
,
num_workers
=
N_SGLANG_WORKERS
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
test_nats_interruption
=
use_nats_core
,
nats_server
=
nats_process
if
use_nats_core
else
None
,
)
)
logger
.
info
(
"SGLang indexers sync test completed successfully"
)
logger
.
info
(
"SGLang indexers sync test completed successfully"
)
...
...
tests/router/test_router_e2e_with_trtllm.py
View file @
584020f4
...
@@ -84,6 +84,7 @@ class TRTLLMProcess:
...
@@ -84,6 +84,7 @@ class TRTLLMProcess:
single_gpu
:
bool
=
False
,
single_gpu
:
bool
=
False
,
request_plane
:
str
=
"tcp"
,
request_plane
:
str
=
"tcp"
,
store_backend
:
str
=
"etcd"
,
store_backend
:
str
=
"etcd"
,
enable_local_indexer
:
bool
=
False
,
):
):
"""Initialize TRT-LLM workers with dynamo integration.
"""Initialize TRT-LLM workers with dynamo integration.
...
@@ -98,6 +99,7 @@ class TRTLLMProcess:
...
@@ -98,6 +99,7 @@ class TRTLLMProcess:
single_gpu: If True, all workers share GPU 0
single_gpu: If True, all workers share GPU 0
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
enable_local_indexer: If True, enable worker-local KV indexer for NATS Core mode. Defaults to False.
Note: TRT-LLM doesn't support data parallelism like vLLM (dp_rank is always 0).
Note: TRT-LLM doesn't support data parallelism like vLLM (dp_rank is always 0).
Tensor parallelism (TP) is supported but creates 1 worker spanning multiple GPUs,
Tensor parallelism (TP) is supported but creates 1 worker spanning multiple GPUs,
...
@@ -172,6 +174,10 @@ class TRTLLMProcess:
...
@@ -172,6 +174,10 @@ class TRTLLMProcess:
if
self
.
store_backend
==
"file"
and
"DYN_FILE_KV"
in
os
.
environ
:
if
self
.
store_backend
==
"file"
and
"DYN_FILE_KV"
in
os
.
environ
:
env_vars
[
"DYN_FILE_KV"
]
=
os
.
environ
[
"DYN_FILE_KV"
]
env_vars
[
"DYN_FILE_KV"
]
=
os
.
environ
[
"DYN_FILE_KV"
]
# Enable local indexer for NATS Core mode
if
enable_local_indexer
:
env_vars
[
"DYN_LOCAL_INDEXER"
]
=
"true"
env
.
update
(
env_vars
)
env
.
update
(
env_vars
)
# Create managed process for the worker
# Create managed process for the worker
...
@@ -286,7 +292,7 @@ class TRTLLMProcess:
...
@@ -286,7 +292,7 @@ class TRTLLMProcess:
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~45s/test), rounded up
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~45s/test), rounded up
def
test_trtllm_kv_router_basic
(
def
test_trtllm_kv_router_basic
(
request
,
request
,
...
@@ -340,7 +346,7 @@ def test_trtllm_kv_router_basic(
...
@@ -340,7 +346,7 @@ def test_trtllm_kv_router_basic(
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~45s/test), rounded up
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~45s/test), rounded up
def
test_router_decisions_trtllm_multiple_workers
(
def
test_router_decisions_trtllm_multiple_workers
(
request
,
request
,
...
@@ -398,10 +404,10 @@ def test_router_decisions_trtllm_multiple_workers(
...
@@ -398,10 +404,10 @@ def test_router_decisions_trtllm_multiple_workers(
"store_backend,use_nats_core,request_plane"
,
"store_backend,use_nats_core,request_plane"
,
[
[
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
# ("etcd", True, "tcp"), #
ignored, needs unconditional nats_client
# ("etcd", True, "tcp"), #
nats_core mode - disabled for now
# ("file", False, "nats"), # File backend - TODO: investigate file backend support for TRT-LLM
# ("file", False, "nats"), # File backend - TODO: investigate file backend support for TRT-LLM
],
],
ids
=
[
"jetstream"
],
# "nats_core" and "file" commented out
ids
=
[
"jetstream"
],
)
)
def
test_trtllm_indexers_sync
(
def
test_trtllm_indexers_sync
(
request
,
request
,
...
@@ -419,8 +425,11 @@ def test_trtllm_indexers_sync(
...
@@ -419,8 +425,11 @@ def test_trtllm_indexers_sync(
Tests with configuration:
Tests with configuration:
- jetstream: etcd backend, JetStream for KV events, NATS request plane
- jetstream: etcd backend, JetStream for KV events, NATS request plane
- tcp_nats_core: etcd backend, local indexer with NATS Core, TCP request plane
"""
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
# runtime_services_dynamic_ports handles NATS and etcd startup
nats_process
,
_etcd_process
=
runtime_services_dynamic_ports
logger
.
info
(
logger
.
info
(
f
"Starting TRT-LLM indexers sync test: store_backend=
{
store_backend
}
, "
f
"Starting TRT-LLM indexers sync test: store_backend=
{
store_backend
}
, "
f
"use_nats_core=
{
use_nats_core
}
, request_plane=
{
request_plane
}
"
f
"use_nats_core=
{
use_nats_core
}
, request_plane=
{
request_plane
}
"
...
@@ -438,6 +447,7 @@ def test_trtllm_indexers_sync(
...
@@ -438,6 +447,7 @@ def test_trtllm_indexers_sync(
single_gpu
=
True
,
# fit workers into one GPU
single_gpu
=
True
,
# fit workers into one GPU
request_plane
=
request_plane
,
request_plane
=
request_plane
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
enable_local_indexer
=
use_nats_core
,
)
)
logger
.
info
(
f
"All TRT-LLM workers using namespace:
{
trtllm_workers
.
namespace
}
"
)
logger
.
info
(
f
"All TRT-LLM workers using namespace:
{
trtllm_workers
.
namespace
}
"
)
trtllm_workers
.
__enter__
()
trtllm_workers
.
__enter__
()
...
@@ -451,6 +461,8 @@ def test_trtllm_indexers_sync(
...
@@ -451,6 +461,8 @@ def test_trtllm_indexers_sync(
num_workers
=
N_TRTLLM_WORKERS
,
num_workers
=
N_TRTLLM_WORKERS
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
test_nats_interruption
=
use_nats_core
,
nats_server
=
nats_process
if
use_nats_core
else
None
,
)
)
logger
.
info
(
"TRT-LLM indexers sync test completed successfully"
)
logger
.
info
(
"TRT-LLM indexers sync test completed successfully"
)
...
...
tests/router/test_router_e2e_with_vllm.py
View file @
584020f4
...
@@ -87,6 +87,7 @@ class VLLMProcess:
...
@@ -87,6 +87,7 @@ class VLLMProcess:
data_parallel_size
:
Optional
[
int
]
=
None
,
data_parallel_size
:
Optional
[
int
]
=
None
,
request_plane
:
str
=
"tcp"
,
request_plane
:
str
=
"tcp"
,
store_backend
:
str
=
"etcd"
,
store_backend
:
str
=
"etcd"
,
enable_local_indexer
:
bool
=
False
,
):
):
"""Initialize vLLM workers with dynamo integration.
"""Initialize vLLM workers with dynamo integration.
...
@@ -104,6 +105,7 @@ class VLLMProcess:
...
@@ -104,6 +105,7 @@ class VLLMProcess:
data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size)
data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size)
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
enable_local_indexer: If True, enable worker-local KV indexer for NATS Core mode. Defaults to False.
"""
"""
# Generate unique namespace for isolation
# Generate unique namespace for isolation
namespace_suffix
=
generate_random_suffix
()
namespace_suffix
=
generate_random_suffix
()
...
@@ -209,6 +211,10 @@ class VLLMProcess:
...
@@ -209,6 +211,10 @@ class VLLMProcess:
if
self
.
store_backend
==
"file"
and
"DYN_FILE_KV"
in
os
.
environ
:
if
self
.
store_backend
==
"file"
and
"DYN_FILE_KV"
in
os
.
environ
:
env_vars
[
"DYN_FILE_KV"
]
=
os
.
environ
[
"DYN_FILE_KV"
]
env_vars
[
"DYN_FILE_KV"
]
=
os
.
environ
[
"DYN_FILE_KV"
]
# Enable local indexer for NATS Core mode
if
enable_local_indexer
:
env_vars
[
"DYN_LOCAL_INDEXER"
]
=
"true"
env
.
update
(
env_vars
)
env
.
update
(
env_vars
)
# Create managed process for the worker
# Create managed process for the worker
...
@@ -329,7 +335,7 @@ class VLLMProcess:
...
@@ -329,7 +335,7 @@ class VLLMProcess:
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~43s/test), rounded up
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~43s/test), rounded up
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
def
test_vllm_kv_router_basic
(
def
test_vllm_kv_router_basic
(
request
,
request
,
runtime_services_dynamic_ports
,
runtime_services_dynamic_ports
,
...
@@ -383,7 +389,7 @@ def test_vllm_kv_router_basic(
...
@@ -383,7 +389,7 @@ def test_vllm_kv_router_basic(
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
pre_merge
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
gpu_1
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~43s/test), rounded up
@
pytest
.
mark
.
timeout
(
150
)
# ~3x average (~43s/test), rounded up
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
def
test_router_decisions_vllm_multiple_workers
(
def
test_router_decisions_vllm_multiple_workers
(
request
,
request
,
runtime_services_dynamic_ports
,
runtime_services_dynamic_ports
,
...
@@ -428,7 +434,7 @@ def test_router_decisions_vllm_multiple_workers(
...
@@ -428,7 +434,7 @@ def test_router_decisions_vllm_multiple_workers(
@
pytest
.
mark
.
gpu_2
@
pytest
.
mark
.
gpu_2
@
pytest
.
mark
.
nightly
@
pytest
.
mark
.
nightly
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"nats"
,
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"request_plane"
,
[
"tcp"
],
indirect
=
True
)
@
pytest
.
mark
.
timeout
(
600
)
# 10 min max (multi-GPU + DP startup variance)
@
pytest
.
mark
.
timeout
(
600
)
# 10 min max (multi-GPU + DP startup variance)
def
test_router_decisions_vllm_dp
(
def
test_router_decisions_vllm_dp
(
request
,
request
,
...
@@ -484,10 +490,10 @@ def test_router_decisions_vllm_dp(
...
@@ -484,10 +490,10 @@ def test_router_decisions_vllm_dp(
"store_backend,use_nats_core,request_plane"
,
"store_backend,use_nats_core,request_plane"
,
[
[
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
(
"etcd"
,
False
,
"nats"
),
# JetStream mode
(
"etcd"
,
True
,
"tcp"
),
# nats_core mode
#
("etcd", True, "tcp"), # nats_core mode
- disabled for now
# ("file", False, "nats"), # File backend
# ("file", False, "nats"), # File backend
],
],
ids
=
[
"jetstream"
,
"tcp_nats_core"
],
ids
=
[
"jetstream"
],
)
)
def
test_vllm_indexers_sync
(
def
test_vllm_indexers_sync
(
request
,
request
,
...
@@ -508,6 +514,8 @@ def test_vllm_indexers_sync(
...
@@ -508,6 +514,8 @@ def test_vllm_indexers_sync(
- tcp_nats_core: etcd backend, local indexer with NATS Core, TCP request plane
- tcp_nats_core: etcd backend, local indexer with NATS Core, TCP request plane
"""
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
# runtime_services_dynamic_ports handles NATS and etcd startup
nats_process
,
_etcd_process
=
runtime_services_dynamic_ports
logger
.
info
(
logger
.
info
(
f
"Starting vLLM indexers sync test: store_backend=
{
store_backend
}
, "
f
"Starting vLLM indexers sync test: store_backend=
{
store_backend
}
, "
f
"use_nats_core=
{
use_nats_core
}
, request_plane=
{
request_plane
}
"
f
"use_nats_core=
{
use_nats_core
}
, request_plane=
{
request_plane
}
"
...
@@ -525,6 +533,7 @@ def test_vllm_indexers_sync(
...
@@ -525,6 +533,7 @@ def test_vllm_indexers_sync(
single_gpu
=
True
,
# fit workers into one GPU
single_gpu
=
True
,
# fit workers into one GPU
request_plane
=
request_plane
,
request_plane
=
request_plane
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
enable_local_indexer
=
use_nats_core
,
)
)
logger
.
info
(
f
"All vLLM workers using namespace:
{
vllm_workers
.
namespace
}
"
)
logger
.
info
(
f
"All vLLM workers using namespace:
{
vllm_workers
.
namespace
}
"
)
vllm_workers
.
__enter__
()
vllm_workers
.
__enter__
()
...
@@ -538,6 +547,8 @@ def test_vllm_indexers_sync(
...
@@ -538,6 +547,8 @@ def test_vllm_indexers_sync(
num_workers
=
N_VLLM_WORKERS
,
num_workers
=
N_VLLM_WORKERS
,
store_backend
=
store_backend
,
store_backend
=
store_backend
,
request_plane
=
request_plane
,
request_plane
=
request_plane
,
test_nats_interruption
=
use_nats_core
,
nats_server
=
nats_process
if
use_nats_core
else
None
,
)
)
logger
.
info
(
"vLLM indexers sync test completed successfully"
)
logger
.
info
(
"vLLM indexers sync test completed successfully"
)
...
...
tests/utils/managed_process.py
View file @
584020f4
...
@@ -35,7 +35,7 @@ def terminate_process(process, logger=logging.getLogger(), immediate_kill=False)
...
@@ -35,7 +35,7 @@ def terminate_process(process, logger=logging.getLogger(), immediate_kill=False)
def
terminate_process_tree
(
def
terminate_process_tree
(
pid
,
logger
=
logging
.
getLogger
(),
immediate_kill
=
False
,
timeout
=
10
pid
,
logger
=
logging
.
getLogger
(),
immediate_kill
=
False
,
timeout
=
2
):
):
try
:
try
:
parent
=
psutil
.
Process
(
pid
)
parent
=
psutil
.
Process
(
pid
)
...
@@ -277,7 +277,7 @@ class ManagedProcess:
...
@@ -277,7 +277,7 @@ class ManagedProcess:
)
)
self
.
_tee_proc
=
None
self
.
_tee_proc
=
None
def
_terminate_process_group
(
self
,
timeout
:
float
=
5
.0
):
def
_terminate_process_group
(
self
,
timeout
:
float
=
2
.0
):
"""Terminate the entire process group/session started for the child.
"""Terminate the entire process group/session started for the child.
This catches cases where the launcher shell exits and its children are reparented,
This catches cases where the launcher shell exits and its children are reparented,
...
@@ -296,10 +296,23 @@ class ManagedProcess:
...
@@ -296,10 +296,23 @@ class ManagedProcess:
)
)
return
return
# Give processes a brief moment to exit gracefully
# Poll for process exit instead of fixed sleep to minimize teardown time
time
.
sleep
(
timeout
)
poll_interval
=
0.1
elapsed
=
0.0
# Force kill if anything remains
while
elapsed
<
timeout
:
try
:
# Check if any process in the group is still alive
os
.
killpg
(
self
.
_pgid
,
0
)
# Signal 0 = check existence
except
ProcessLookupError
:
# Process group no longer exists - done
return
except
Exception
:
# Other errors (e.g., permission) - assume done
return
time
.
sleep
(
poll_interval
)
elapsed
+=
poll_interval
# Force kill if anything remains after timeout
try
:
try
:
os
.
killpg
(
self
.
_pgid
,
signal
.
SIGKILL
)
os
.
killpg
(
self
.
_pgid
,
signal
.
SIGKILL
)
except
ProcessLookupError
:
except
ProcessLookupError
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment