Unverified Commit 85d83108 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: router-level request rejection (#2465)

parent 1945f599
......@@ -143,6 +143,12 @@ def parse_args():
default=False,
help="KV Router: Enable replica synchronization across multiple router instances. When true, routers will publish and subscribe to events to maintain consistent state.",
)
parser.add_argument(
"--busy-threshold",
type=float,
default=None,
help="Threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. If not set, busy detection is disabled.",
)
parser.add_argument(
"--static-endpoint",
type=validate_static_endpoint,
......@@ -205,7 +211,9 @@ async def async_main():
kwargs = {
"http_port": flags.http_port,
"kv_cache_block_size": flags.kv_cache_block_size,
"router_config": RouterConfig(router_mode, kv_router_config),
"router_config": RouterConfig(
router_mode, kv_router_config, flags.busy_threshold
),
}
if flags.static_endpoint:
......
......@@ -60,16 +60,22 @@ impl KvRouterConfig {
pub struct RouterConfig {
router_mode: RouterMode,
kv_router_config: KvRouterConfig,
busy_threshold: Option<f64>,
}
#[pymethods]
impl RouterConfig {
#[new]
#[pyo3(signature = (mode, config=None))]
pub fn new(mode: RouterMode, config: Option<KvRouterConfig>) -> Self {
#[pyo3(signature = (mode, config=None, busy_threshold=None))]
pub fn new(
mode: RouterMode,
config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
) -> Self {
Self {
router_mode: mode,
kv_router_config: config.unwrap_or_default(),
busy_threshold,
}
}
}
......@@ -79,6 +85,7 @@ impl From<RouterConfig> for RsRouterConfig {
RsRouterConfig {
router_mode: rc.router_mode.into(),
kv_router_config: rc.kv_router_config.inner,
busy_threshold: rc.busy_threshold,
}
}
}
......
......@@ -50,6 +50,7 @@ pub struct ModelWatcher {
notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
}
const ALL_MODEL_TYPES: &[ModelType] =
......@@ -61,6 +62,7 @@ impl ModelWatcher {
model_manager: Arc<ModelManager>,
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
) -> ModelWatcher {
Self {
manager: model_manager,
......@@ -69,6 +71,7 @@ impl ModelWatcher {
notify_on_model: Notify::new(),
model_update_tx: None,
kv_router_config,
busy_threshold,
}
}
......@@ -316,20 +319,30 @@ impl ModelWatcher {
None
};
let chat_engine =
entrypoint::build_routed_pipeline::<
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(&card, &client, self.router_mode, kv_chooser.clone())
>(
&card,
&client,
self.router_mode,
self.busy_threshold,
kv_chooser.clone(),
)
.await?;
self.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?;
let completions_engine =
entrypoint::build_routed_pipeline::<
let completions_engine = entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(&card, &client, self.router_mode, kv_chooser)
>(
&card,
&client,
self.router_mode,
self.busy_threshold,
kv_chooser,
)
.await?;
self.manager
.add_completions_model(&model_entry.name, completions_engine)?;
......@@ -338,7 +351,9 @@ impl ModelWatcher {
let push_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(client, Default::default())
>::from_client_with_threshold(
client, Default::default(), self.busy_threshold
)
.await?;
let engine = Arc::new(push_router);
self.manager
......@@ -348,7 +363,9 @@ impl ModelWatcher {
let push_router = PushRouter::<
NvCreateCompletionRequest,
Annotated<NvCreateCompletionResponse>,
>::from_client(client, Default::default())
>::from_client_with_threshold(
client, Default::default(), self.busy_threshold
)
.await?;
let engine = Arc::new(push_router);
self.manager
......@@ -374,7 +391,9 @@ impl ModelWatcher {
let router = PushRouter::<
PreprocessedEmbeddingRequest,
Annotated<EmbeddingsEngineOutput>,
>::from_client(client, self.router_mode)
>::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
)
.await?;
// Note: Embeddings don't need KV routing complexity
......
......@@ -21,6 +21,7 @@ use crate::{
pub struct RouterConfig {
pub router_mode: RouterMode,
pub kv_router_config: KvRouterConfig,
pub busy_threshold: Option<f64>,
}
impl RouterConfig {
......@@ -28,8 +29,14 @@ impl RouterConfig {
Self {
router_mode,
kv_router_config,
busy_threshold: None,
}
}
pub fn with_busy_threshold(mut self, threshold: Option<f64>) -> Self {
self.busy_threshold = threshold;
self
}
}
#[derive(Clone)]
......
......@@ -71,6 +71,7 @@ pub async fn prepare_engine(
model_manager.clone(),
dynamo_runtime::pipeline::RouterMode::RoundRobin,
None,
None,
));
let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
......@@ -133,7 +134,7 @@ pub async fn prepare_engine(
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(card, &client, router_mode, kv_chooser.clone())
>(card, &client, router_mode, None, kv_chooser.clone())
.await?;
let service_name = local_model.service_name().to_string();
......@@ -216,6 +217,7 @@ pub async fn build_routed_pipeline<Req, Resp>(
card: &ModelDeploymentCard,
client: &Client,
router_mode: RouterMode,
busy_threshold: Option<f64>,
chooser: Option<Arc<KvRouter>>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
......@@ -232,9 +234,11 @@ where
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let migration = Migration::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client.clone(),
router_mode,
busy_threshold,
)
.await?;
let service_backend = match router_mode {
......
......@@ -66,6 +66,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
MODEL_ROOT_PATH,
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
Arc::new(http_service.clone()),
)
.await?;
......@@ -109,14 +110,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(card, &client, router_mode, kv_chooser.clone())
>(card, &client, router_mode, None, kv_chooser.clone())
.await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;
let completions_engine = entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(card, &client, router_mode, kv_chooser)
>(card, &client, router_mode, None, kv_chooser)
.await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?;
......@@ -188,6 +189,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
/// Spawns a task that watches for new models in etcd at network_prefix,
/// and registers them with the ModelManager so that the HTTP service can use them.
#[allow(clippy::too_many_arguments)]
async fn run_watcher(
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
......@@ -195,9 +197,16 @@ async fn run_watcher(
network_prefix: &str,
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
http_service: Arc<HttpService>,
) -> anyhow::Result<()> {
let mut watch_obj = ModelWatcher::new(runtime, model_manager, router_mode, kv_router_config);
let mut watch_obj = ModelWatcher::new(
runtime,
model_manager,
router_mode,
kv_router_config,
busy_threshold,
);
tracing::info!("Watching for remote model at {network_prefix}");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
......
......@@ -108,6 +108,24 @@ impl ErrorMessage {
/// If successful, it will return the [`HttpError`] as an [`ErrorMessage::internal_server_error`]
/// with the details of the error.
pub fn from_anyhow(err: anyhow::Error, alt_msg: &str) -> ErrorResponse {
// First check for PipelineError::ServiceOverloaded
if let Some(pipeline_err) =
err.downcast_ref::<dynamo_runtime::pipeline::error::PipelineError>()
{
if matches!(
pipeline_err,
dynamo_runtime::pipeline::error::PipelineError::ServiceOverloaded(_)
) {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorMessage {
error: pipeline_err.to_string(),
}),
);
}
}
// Then check for HttpError
match err.downcast::<HttpError>() {
Ok(http_error) => ErrorMessage::from_http_error(http_error),
Err(err) => ErrorMessage::internal_server_error(&format!("{alt_msg}: {err}")),
......@@ -1150,6 +1168,22 @@ mod tests {
);
}
#[test]
fn test_service_overloaded_error_response_from_anyhow() {
use dynamo_runtime::pipeline::error::PipelineError;
let err: anyhow::Error = PipelineError::ServiceOverloaded(
"All workers are busy, please retry later".to_string(),
)
.into();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
assert_eq!(
response.error,
"Service temporarily unavailable: All workers are busy, please retry later"
);
}
#[test]
fn test_validate_input_is_text_only_accepts_text() {
let request = make_base_request();
......
......@@ -29,13 +29,13 @@ pub mod scoring;
pub mod sequence;
use crate::{
discovery::{ModelEntry, MODEL_ROOT_PATH},
kv_router::{
approx::ApproxKvIndexer,
indexer::{
compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface,
KvRouterError, OverlapScores, RouterEvent,
},
metrics_aggregator::watch_model_runtime_configs,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
......@@ -177,14 +177,25 @@ impl KvRouter {
}
};
// Create runtime config watcher
// Create runtime config watcher using the generic etcd watcher
// TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality
let etcd_client = component
.drt()
.etcd_client()
.expect("Cannot KV route without etcd client");
let runtime_configs_rx =
watch_model_runtime_configs(etcd_client, cancellation_token.clone()).await?;
use dynamo_runtime::utils::typed_prefix_watcher::{
key_extractors, watch_prefix_with_extraction,
};
let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client,
MODEL_ROOT_PATH,
key_extractors::lease_id,
|model_entry: ModelEntry| model_entry.runtime_config,
cancellation_token.clone(),
)
.await?;
let runtime_configs_rx = runtime_configs_watcher.receiver();
let indexer = if kv_router_config.use_kv_events {
Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
......
......@@ -18,14 +18,10 @@ use std::sync::Once;
pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::KV_METRICS_ENDPOINT;
use crate::discovery::{ModelEntry, MODEL_ROOT_PATH};
use crate::kv_router::scoring::Endpoint;
use crate::kv_router::ProcessedEndpoints;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_runtime::component::Component;
use dynamo_runtime::transports::etcd::{Client as EtcdClient, WatchEvent};
use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result};
use std::collections::HashMap;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
......@@ -212,71 +208,3 @@ pub async fn collect_endpoints_task(
}
}
}
pub async fn watch_model_runtime_configs(
etcd_client: EtcdClient,
cancellation_token: CancellationToken,
) -> Result<watch::Receiver<HashMap<i64, ModelRuntimeConfig>>> {
let (watch_tx, watch_rx) = watch::channel(HashMap::new());
let prefix_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
let (_prefix, _watcher, mut events_rx) = prefix_watcher.dissolve();
tokio::spawn(async move {
let mut runtime_configs: HashMap<i64, ModelRuntimeConfig> = HashMap::new();
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::debug!("Runtime config watcher cancelled");
break;
}
event = events_rx.recv() => {
let Some(event) = event else {
tracing::debug!("Runtime config watch stream closed");
break;
};
match event {
WatchEvent::Put(kv) => {
let Ok(model_entry) = serde_json::from_slice::<ModelEntry>(kv.value()) else {
tracing::warn!(
"Failed to parse ModelEntry from etcd. Key: {}",
kv.key_str().unwrap_or("<invalid>")
);
continue;
};
let lease_id = kv.lease();
if let Some(runtime_config) = model_entry.runtime_config {
runtime_configs.insert(lease_id, runtime_config);
tracing::trace!("Updated runtime config for lease_id: {}", lease_id);
} else {
runtime_configs.remove(&lease_id);
tracing::trace!("Removed runtime config (no config in ModelEntry)");
}
if watch_tx.send(runtime_configs.clone()).is_err() {
tracing::error!("Failed to send runtime configs update; receiver dropped");
break;
}
}
WatchEvent::Delete(kv) => {
let lease_id = kv.lease();
runtime_configs.remove(&lease_id);
tracing::trace!("Removed runtime config for deleted entry");
if watch_tx.send(runtime_configs.clone()).is_err() {
tracing::error!("Failed to send runtime configs update; receiver dropped");
break;
}
}
}
}
}
}
});
Ok(watch_rx)
}
......@@ -503,14 +503,16 @@ impl WorkerMetricsPublisher {
let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone()));
let handler = Ingress::for_engine(handler)?;
// let worker_id = component
// .drt()
// .primary_lease()
// .map(|lease| lease.id())
// .unwrap_or_else(|| {
// tracing::warn!("Component is static, assuming worker_id of 0");
// 0
// });
let worker_id = component
.drt()
.primary_lease()
.map(|lease| lease.id())
.unwrap_or_else(|| {
tracing::warn!("Component is static, assuming worker_id of 0");
0
});
self.start_nats_metrics_publishing(component.namespace().clone(), worker_id);
component
.endpoint(KV_METRICS_ENDPOINT)
......
......@@ -42,8 +42,8 @@ use futures::StreamExt;
use rand::Rng;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex, OnceCell};
use tokio::time::{interval, Duration};
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;
......@@ -174,7 +174,7 @@ impl MockVllmEngine {
(schedulers, kv_event_receivers)
}
/// Start background tasks to poll and publish metrics every second
/// Start background tasks to publish metrics on change
async fn start_metrics_publishing(
schedulers: &[Scheduler],
component: Option<Component>,
......@@ -202,19 +202,18 @@ impl MockVllmEngine {
tracing::info!("Starting metrics background tasks");
for (dp_rank, scheduler) in schedulers.iter().enumerate() {
let scheduler = scheduler.clone();
let mut metrics_rx = scheduler.metrics_receiver();
let publisher = metrics_publisher.clone();
let dp_rank = dp_rank as u32;
let cancel_token = cancel_token.clone();
tokio::spawn(async move {
let mut interval = interval(Duration::from_millis(100));
loop {
tokio::select! {
_ = interval.tick() => {
// Get metrics from scheduler
let metrics = scheduler.get_forward_pass_metrics().await;
// Watch for metrics changes
Ok(_) = metrics_rx.changed() => {
// Get the latest metrics
let metrics = metrics_rx.borrow().clone();
// Publish metrics
if let Err(e) = publisher.publish(Arc::new(metrics)) {
......@@ -568,7 +567,7 @@ mod integration_tests {
let engine = MockVllmEngine::new(args);
engine.start(test_component.clone()).await?;
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
tokio::time::sleep(Duration::from_millis(500)).await;
let engine = Arc::new(engine);
tracing::info!("✓ MockVllmEngine created with DP_SIZE: {DP_SIZE}");
......@@ -598,7 +597,7 @@ mod integration_tests {
tracing::info!("✓ Server started in background");
// Give server time to start
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
tokio::time::sleep(Duration::from_millis(500)).await;
tracing::info!("✓ Server startup delay completed");
// Print all registered instances from etcd
......@@ -733,7 +732,7 @@ mod integration_tests {
cancel_token,
)
.await;
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
tokio::time::sleep(Duration::from_millis(500)).await;
let processed_endpoints = metrics_aggregator.get_endpoints();
tracing::info!(
......
......@@ -250,11 +250,10 @@ impl SchedulerState {
/// Manages scheduling of requests using KvManager resources
#[derive(Clone)]
pub struct Scheduler {
dp_rank: Option<u32>,
state: Arc<Mutex<SchedulerState>>,
kv_manager: Arc<Mutex<KvManager>>,
request_tx: mpsc::UnboundedSender<DirectRequest>,
hit_rates: Arc<Mutex<VecDeque<f32>>>,
metrics_rx: tokio::sync::watch::Receiver<ForwardPassMetrics>,
}
impl Scheduler {
......@@ -292,13 +291,16 @@ impl Scheduler {
// Create channel for request handling
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let mut initial_metrics = ForwardPassMetrics::default();
initial_metrics.worker_stats.data_parallel_rank = dp_rank;
let (metrics_tx, metrics_rx) =
tokio::sync::watch::channel::<ForwardPassMetrics>(initial_metrics);
// Create a clone for the background task
let state_clone = state.clone();
let kv_manager_clone = kv_manager.clone();
let output_tx_clone = output_tx.clone();
let cancel_token_clone = cancellation_token.unwrap_or_default().clone();
let hit_rates_clone = hit_rates.clone();
// Spawn main background task with cancellation token
tokio::spawn(async move {
......@@ -376,7 +378,7 @@ impl Scheduler {
// Compute and store hit rate
let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 };
{
let mut hit_rates_guard = hit_rates_clone.lock().await;
let mut hit_rates_guard = hit_rates.lock().await;
hit_rates_guard.push_back(hit_rate);
if hit_rates_guard.len() > 1000 {
hit_rates_guard.pop_front();
......@@ -442,6 +444,17 @@ impl Scheduler {
state_guard.reset_active_tokens();
{
let hit_rates_guard = hit_rates.lock().await;
let metrics = get_fwd_pass_metrics(
&state_guard,
&kv_manager_guard,
&hit_rates_guard,
dp_rank,
);
let _ = metrics_tx.send(metrics);
}
// Process decoding
let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
if !uuids.is_empty() {
......@@ -495,6 +508,17 @@ impl Scheduler {
}
}
{
let hit_rates_guard = hit_rates.lock().await;
let metrics = get_fwd_pass_metrics(
&state_guard,
&kv_manager_guard,
&hit_rates_guard,
dp_rank,
);
let _ = metrics_tx.send(metrics);
}
if send_failed || is_complete {
state_guard.complete(&uuid);
continue;
......@@ -513,11 +537,10 @@ impl Scheduler {
});
Self {
dp_rank,
state,
kv_manager,
request_tx,
hit_rates,
metrics_rx,
}
}
......@@ -555,13 +578,19 @@ impl Scheduler {
kv_manager.current_capacity_perc()
}
/// Returns forward pass metrics for monitoring purposes
pub async fn get_forward_pass_metrics(&self) -> ForwardPassMetrics {
// Acquire all locks in consistent order: state -> kv_manager -> hit_rates
let state = self.state.lock().await;
let kv_manager = self.kv_manager.lock().await;
let hit_rates_guard = self.hit_rates.lock().await;
/// Get a watch receiver for forward pass metrics
pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<ForwardPassMetrics> {
self.metrics_rx.clone()
}
}
/// Calculate forward pass metrics from current state
fn get_fwd_pass_metrics(
state: &SchedulerState,
kv_manager: &KvManager,
hit_rates: &VecDeque<f32>,
dp_rank: Option<u32>,
) -> ForwardPassMetrics {
// Get state metrics
let request_active_slots = state.decode.len() as u64;
let num_requests_waiting = state.waiting.len() as u64;
......@@ -576,15 +605,15 @@ impl Scheduler {
};
// Get hit rate metrics
let gpu_prefix_cache_hit_rate = if hit_rates_guard.is_empty() {
let gpu_prefix_cache_hit_rate = if hit_rates.is_empty() {
0.0
} else {
let sum: f32 = hit_rates_guard.iter().sum();
sum / hit_rates_guard.len() as f32
let sum: f32 = hit_rates.iter().sum();
sum / hit_rates.len() as f32
};
let worker_stats = WorkerStats {
data_parallel_rank: self.dp_rank,
data_parallel_rank: dp_rank,
request_active_slots,
request_total_slots: 1024, // vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128
num_requests_waiting,
......@@ -604,8 +633,6 @@ impl Scheduler {
kv_stats,
spec_decode_stats,
}
// Guards drop naturally here in reverse order (LIFO): hit_rates_guard, kv_manager, state
}
}
/// Convert a Request to an ActiveSequence
......@@ -761,6 +788,9 @@ mod tests {
let timeout = tokio::time::sleep(Duration::from_secs(2));
tokio::pin!(timeout);
// Get metrics receiver
let metrics_rx = scheduler.metrics_receiver();
// Set up debug ticker interval
let mut debug_interval = interval(Duration::from_millis(500));
......@@ -770,7 +800,7 @@ mod tests {
// Manual debug ticker that prints forward pass metrics
_ = debug_interval.tick() => {
let _metrics = scheduler.get_forward_pass_metrics().await;
let _metrics = metrics_rx.borrow().clone();
println!("Forward Pass Metrics: {_metrics:#?}");
}
......@@ -862,6 +892,9 @@ mod tests {
let timeout = tokio::time::sleep(Duration::from_millis(500));
tokio::pin!(timeout);
// Get metrics receiver
let metrics_rx = scheduler.metrics_receiver();
// Set up debug ticker interval
let mut debug_interval = interval(Duration::from_millis(500));
......@@ -871,7 +904,7 @@ mod tests {
// Manual debug ticker that prints forward pass metrics
_ = debug_interval.tick() => {
let _metrics = scheduler.get_forward_pass_metrics().await;
let _metrics = metrics_rx.borrow().clone();
println!("Forward Pass Metrics: {_metrics:#?}");
}
......@@ -888,8 +921,11 @@ mod tests {
}
}
// Wait a bit for final metrics update
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify forward pass metrics
let metrics = scheduler.get_forward_pass_metrics().await;
let metrics = metrics_rx.borrow().clone();
assert_eq!(
metrics.worker_stats.num_requests_waiting, 0,
......@@ -958,7 +994,8 @@ mod tests {
tokio::time::sleep(Duration::from_secs(1)).await;
// Check forward pass metrics
let metrics = scheduler.get_forward_pass_metrics().await;
let metrics_rx = scheduler.metrics_receiver();
let metrics = metrics_rx.borrow().clone();
assert_eq!(
metrics.kv_stats.gpu_cache_usage_perc,
......
......@@ -44,6 +44,8 @@ pub struct Client {
pub instance_source: Arc<InstanceSource>,
// These are the instance source ids less those reported as down from sending rpc
instance_avail: Arc<ArcSwap<Vec<i64>>>,
// These are the instance source ids less those reported as busy (above threshold)
instance_free: Arc<ArcSwap<Vec<i64>>>,
}
#[derive(Clone, Debug)]
......@@ -59,6 +61,7 @@ impl Client {
endpoint,
instance_source: Arc::new(InstanceSource::Static),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
})
}
......@@ -76,8 +79,9 @@ impl Client {
let client = Client {
endpoint,
instance_source,
instance_source: instance_source.clone(),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
};
client.monitor_instance_source();
Ok(client)
......@@ -108,6 +112,10 @@ impl Client {
self.instance_avail.load()
}
pub fn instance_ids_free(&self) -> arc_swap::Guard<Arc<Vec<i64>>> {
self.instance_free.load()
}
/// Wait for at least one Instance to be available for this Endpoint
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
let mut instances: Vec<Instance> = vec![];
......@@ -142,6 +150,16 @@ impl Client {
tracing::debug!("inhibiting instance {instance_id}");
}
/// Update the set of free instances based on busy instance IDs
pub fn update_free_instances(&self, busy_instance_ids: &[i64]) {
let all_instance_ids = self.instance_ids();
let free_ids: Vec<i64> = all_instance_ids
.into_iter()
.filter(|id| !busy_instance_ids.contains(id))
.collect();
self.instance_free.store(Arc::new(free_ids));
}
/// Monitor the ETCD instance source and update instance_avail.
fn monitor_instance_source(&self) {
let cancel_token = self.endpoint.drt().primary_token();
......@@ -160,7 +178,10 @@ impl Client {
.iter()
.map(|instance| instance.id())
.collect();
client.instance_avail.store(Arc::new(instance_ids));
// TODO: this resets both tracked available and free instances
client.instance_avail.store(Arc::new(instance_ids.clone()));
client.instance_free.store(Arc::new(instance_ids));
tracing::debug!("instance source updated");
......
......@@ -131,6 +131,10 @@ pub enum PipelineError {
#[error("NATS KV Err: {0} for bucket '{1}")]
KeyValueError(String, String),
/// All instances are busy and cannot handle new requests
#[error("Service temporarily unavailable: {0}")]
ServiceOverloaded(String),
}
#[derive(Debug, thiserror::Error)]
......
......@@ -2,11 +2,13 @@
// SPDX-License-Identifier: Apache-2.0
use super::{AsyncEngineContextProvider, ResponseStream};
use crate::utils::worker_monitor::WorkerMonitor;
use crate::{
component::{Client, Endpoint, InstanceSource},
engine::{AsyncEngine, Data},
pipeline::{
error::PipelineErrorExt, AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
error::{PipelineError, PipelineErrorExt},
AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
},
protocols::maybe_error::MaybeError,
traits::DistributedRuntimeProvider,
......@@ -52,6 +54,13 @@ where
/// addresses it, then passes it to AddressedPushRouter which does the network traffic.
addressed: Arc<AddressedPushRouter>,
/// Worker monitor for tracking KV cache usage
worker_monitor: Option<Arc<WorkerMonitor>>,
/// Threshold for determining when a worker is busy (0.0 to 1.0)
/// If None, busy detection is disabled
busy_threshold: Option<f64>,
/// An internal Rust type. This says that PushRouter is generic over the T and U types,
/// which are the input and output types of it's `generate` function. It allows the
/// compiler to specialize us at compile time.
......@@ -86,15 +95,43 @@ where
T: Data + Serialize,
U: Data + for<'de> Deserialize<'de> + MaybeError,
{
/// Create a new PushRouter without busy threshold (no busy detection)
pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
Self::from_client_with_threshold(client, router_mode, None).await
}
/// Create a new PushRouter with optional busy threshold
pub async fn from_client_with_threshold(
client: Client,
router_mode: RouterMode,
busy_threshold: Option<f64>,
) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?;
Ok(PushRouter {
client,
// Create worker monitor only if we have a threshold and are in dynamic mode
let worker_monitor = match (busy_threshold, client.instance_source.as_ref()) {
(Some(threshold), InstanceSource::Dynamic(_)) => {
let monitor = Arc::new(WorkerMonitor::new_with_threshold(
Arc::new(client.clone()),
threshold,
));
monitor.start_monitoring().await?;
Some(monitor)
}
_ => None,
};
let router = PushRouter {
client: client.clone(),
addressed,
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
worker_monitor,
busy_threshold,
_phantom: PhantomData,
})
};
Ok(router)
}
/// Issue a request to the next available instance in a round-robin fashion
......@@ -170,6 +207,21 @@ where
instance_id: i64,
request: SingleIn<T>,
) -> anyhow::Result<ManyOut<U>> {
// Check if all workers are busy (only if busy threshold is set)
if self.busy_threshold.is_some() {
let free_instances = self.client.instance_ids_free();
if free_instances.is_empty() {
// Check if we actually have any instances at all
let all_instances = self.client.instance_ids();
if !all_instances.is_empty() {
return Err(PipelineError::ServiceOverloaded(
"All workers are busy, please retry later".to_string(),
)
.into());
}
}
}
let subject = self.client.endpoint.subject_to(instance_id);
let request = request.map(|req| AddressedRequest::new(req, subject));
......
......@@ -19,3 +19,5 @@ pub mod leader_worker_barrier;
pub mod pool;
pub mod stream;
pub mod task;
pub mod typed_prefix_watcher;
pub mod worker_monitor;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Generic etcd watcher utilities for maintaining collated state from etcd prefixes.
//!
//! This module provides reusable patterns for watching etcd prefixes and maintaining
//! HashMap-based state that automatically updates based on etcd events.
use crate::transports::etcd::{Client as EtcdClient, WatchEvent};
use crate::Result;
use etcd_client::KeyValue;
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::fmt::Debug;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
/// A generic etcd prefix watcher that maintains a HashMap of deserialized values.
///
/// This struct watches an etcd prefix and maintains a HashMap where:
/// - Keys are extracted from the etcd KeyValue (e.g., lease_id, key string, etc.)
/// - Values are extracted from the deserialized type using a value extractor
///
/// # Type Parameters
/// - `K`: The key type for the HashMap (must be hashable)
/// - `V`: The value type stored in the HashMap
pub struct TypedPrefixWatcher<K, V>
where
K: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
{
rx: watch::Receiver<HashMap<K, V>>,
}
impl<K, V> TypedPrefixWatcher<K, V>
where
K: Clone + Eq + std::hash::Hash + Send + Sync + Debug + 'static,
V: Clone + Send + Sync + 'static,
{
/// Get a receiver for the current state
pub fn receiver(&self) -> watch::Receiver<HashMap<K, V>> {
self.rx.clone()
}
/// Get the current state
pub fn current(&self) -> HashMap<K, V> {
self.rx.borrow().clone()
}
}
/// Watch an etcd prefix and maintain a HashMap of values with field extraction
///
/// This function watches an etcd prefix and maintains a HashMap where values are
/// extracted from a deserialized type using a value extractor function.
///
/// # Type Parameters
/// - `K`: The key type for the HashMap
/// - `V`: The value type stored in the HashMap
/// - `T`: The type to deserialize from etcd
///
/// # Arguments
/// - `client`: The etcd client to use
/// - `prefix`: The prefix to watch in etcd
/// - `key_extractor`: Function to extract the key from a KeyValue
/// - `value_extractor`: Function to extract the value from the deserialized type
/// - `cancellation_token`: Token to stop the watcher
///
/// # Example
/// ```ignore
/// // Watch for ModelEntry objects and extract runtime_config field
/// let watcher = watch_prefix_with_extraction(
/// etcd_client,
/// "models/",
/// |kv| Some(kv.lease()), // Use lease_id as key
/// |entry: ModelEntry| entry.runtime_config, // Extract runtime_config field
/// cancellation_token,
/// ).await?;
/// ```
pub async fn watch_prefix_with_extraction<K, V, T>(
client: EtcdClient,
prefix: impl Into<String>,
key_extractor: impl Fn(&KeyValue) -> Option<K> + Send + 'static,
value_extractor: impl Fn(T) -> Option<V> + Send + 'static,
cancellation_token: CancellationToken,
) -> Result<TypedPrefixWatcher<K, V>>
where
K: Clone + Eq + std::hash::Hash + Send + Sync + Debug + 'static,
V: Clone + Send + Sync + 'static,
T: DeserializeOwned + Send + 'static,
{
let (watch_tx, watch_rx) = watch::channel(HashMap::new());
let prefix = prefix.into();
let prefix_watcher = client.kv_get_and_watch_prefix(&prefix).await?;
let (prefix_str, _watcher, mut events_rx) = prefix_watcher.dissolve();
tokio::spawn(async move {
let mut state: HashMap<K, V> = HashMap::new();
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::debug!("TypedPrefixWatcher for prefix '{}' cancelled", prefix_str);
break;
}
event = events_rx.recv() => {
let Some(event) = event else {
tracing::debug!("TypedPrefixWatcher watch stream closed for prefix '{}'", prefix_str);
break;
};
match event {
WatchEvent::Put(kv) => {
// Extract the key
let Some(key) = key_extractor(&kv) else {
tracing::trace!("Skipping entry - key extractor returned None");
continue;
};
// Deserialize the value
let deserialized = match serde_json::from_slice::<T>(kv.value()) {
Ok(val) => val,
Err(e) => {
tracing::warn!(
"Failed to deserialize value from etcd. Key: {}, Error: {}",
kv.key_str().unwrap_or("<invalid>"),
e
);
continue;
}
};
// Extract the value
match value_extractor(deserialized) {
Some(v) => {
state.insert(key.clone(), v);
tracing::trace!("Updated entry for key {:?}", key);
}
None => {
state.remove(&key);
tracing::trace!("Removed entry for key {:?} (extractor returned None)", key);
}
}
if watch_tx.send(state.clone()).is_err() {
tracing::error!("Failed to send update; receiver dropped");
break;
}
}
WatchEvent::Delete(kv) => {
if let Some(key) = key_extractor(&kv) {
state.remove(&key);
tracing::trace!("Removed entry for deleted key {:?}", key);
if watch_tx.send(state.clone()).is_err() {
tracing::error!("Failed to send update; receiver dropped");
break;
}
}
}
}
}
}
}
tracing::info!("TypedPrefixWatcher for prefix '{}' stopped", prefix_str);
});
Ok(TypedPrefixWatcher { rx: watch_rx })
}
/// Watch an etcd prefix and maintain a HashMap of values without field extraction
///
/// This is a simpler version when you want to store the entire deserialized value.
///
/// # Example
/// ```ignore
/// // Watch for TestConfig objects directly
/// let watcher = watch_prefix(
/// etcd_client,
/// "configs/",
/// |kv| Some(kv.lease()), // Use lease_id as key
/// cancellation_token,
/// ).await?;
/// ```
pub async fn watch_prefix<K, V>(
client: EtcdClient,
prefix: impl Into<String>,
key_extractor: impl Fn(&KeyValue) -> Option<K> + Send + 'static,
cancellation_token: CancellationToken,
) -> Result<TypedPrefixWatcher<K, V>>
where
K: Clone + Eq + std::hash::Hash + Send + Sync + Debug + 'static,
V: Clone + DeserializeOwned + Send + Sync + 'static,
{
watch_prefix_with_extraction(
client,
prefix,
key_extractor,
|v: V| Some(v), // Identity function - just return the value
cancellation_token,
)
.await
}
/// Common key extractors for convenience
pub mod key_extractors {
use etcd_client::KeyValue;
/// Extract the lease ID as the key
pub fn lease_id(kv: &KeyValue) -> Option<i64> {
Some(kv.lease())
}
/// Extract the key as a string (without prefix)
pub fn key_string(prefix: &str) -> impl Fn(&KeyValue) -> Option<String> {
let prefix = prefix.to_string();
move |kv: &KeyValue| {
kv.key_str()
.ok()
.map(|k| k.strip_prefix(&prefix).unwrap_or(k).to_string())
}
}
/// Extract the full key as a string
pub fn full_key_string(kv: &KeyValue) -> Option<String> {
kv.key_str().ok().map(|s| s.to_string())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// TODO: Make load comparisons and runtime metrics a generic trait so this monitoring
// system is not tied to KV cache concepts, which are LLM-specific. This would allow
// different types of workers to define their own load metrics and busy thresholds.
use crate::component::{Client, InstanceSource};
use crate::traits::events::EventSubscriber;
use crate::traits::DistributedRuntimeProvider;
use crate::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::watch;
use tokio_stream::StreamExt;
// Constants for monitoring configuration
const KV_METRICS_SUBJECT: &str = "kv_metrics";
const MODEL_ROOT_PATH: &str = "models";
// Internal structs for deserializing metrics events
#[derive(serde::Deserialize)]
struct LoadEvent {
worker_id: i64,
data: ForwardPassMetrics,
}
#[derive(serde::Deserialize)]
struct ForwardPassMetrics {
kv_stats: KvStats,
}
#[derive(serde::Deserialize)]
struct KvStats {
kv_active_blocks: u64,
}
#[derive(serde::Deserialize)]
struct ModelEntry {
runtime_config: Option<RuntimeConfig>,
}
#[derive(serde::Deserialize)]
struct RuntimeConfig {
total_kv_blocks: Option<u64>,
}
/// Worker load monitoring state
#[derive(Clone, Debug)]
pub struct WorkerLoadState {
pub kv_active_blocks: Option<u64>,
pub kv_total_blocks: Option<u64>,
}
impl WorkerLoadState {
pub fn is_busy(&self, threshold: f64) -> bool {
match (self.kv_active_blocks, self.kv_total_blocks) {
(Some(active), Some(total)) if total > 0 => {
(active as f64) > (threshold * total as f64)
}
_ => false,
}
}
}
/// Worker monitor for tracking KV cache usage and busy states
pub struct WorkerMonitor {
client: Arc<Client>,
worker_load_states: Arc<RwLock<HashMap<i64, WorkerLoadState>>>,
busy_threshold: f64,
}
impl WorkerMonitor {
/// Create a new worker monitor with custom threshold
pub fn new_with_threshold(client: Arc<Client>, busy_threshold: f64) -> Self {
Self {
client,
worker_load_states: Arc::new(RwLock::new(HashMap::new())),
busy_threshold,
}
}
/// Get the worker load states for external access
pub fn load_states(&self) -> Arc<RwLock<HashMap<i64, WorkerLoadState>>> {
self.worker_load_states.clone()
}
/// Start background monitoring of worker KV cache usage
pub async fn start_monitoring(&self) -> anyhow::Result<()> {
let endpoint = &self.client.endpoint;
let component = endpoint.component();
let Some(etcd_client) = component.drt().etcd_client() else {
// Static mode, no monitoring needed
return Ok(());
};
let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client,
MODEL_ROOT_PATH,
key_extractors::lease_id,
|entry: ModelEntry| entry.runtime_config.and_then(|rc| rc.total_kv_blocks),
component.drt().child_token(),
)
.await?;
let mut config_events_rx = runtime_configs_watcher.receiver();
// Subscribe to KV metrics events
let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;
let worker_load_states = self.worker_load_states.clone();
let client = self.client.clone();
let cancellation_token = component.drt().child_token();
let busy_threshold = self.busy_threshold; // Capture threshold for the closure
// Spawn background monitoring task
tokio::spawn(async move {
let mut previous_busy_instances = Vec::new(); // Track previous state
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::debug!("Worker monitoring cancelled");
break;
}
// Handle runtime config updates - now receives full HashMap
_ = config_events_rx.changed() => {
let runtime_configs = config_events_rx.borrow().clone();
let mut states = worker_load_states.write().unwrap();
states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
// Update worker load states with total blocks
for (lease_id, total_blocks) in runtime_configs.iter() {
let state = states.entry(*lease_id).or_insert(WorkerLoadState {
kv_active_blocks: None,
kv_total_blocks: None,
});
state.kv_total_blocks = Some(*total_blocks);
}
}
// Handle KV metrics updates
kv_event = kv_metrics_rx.next() => {
let Some(event) = kv_event else {
tracing::debug!("KV metrics stream closed");
break;
};
if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) {
let worker_id = load_event.worker_id;
let active_blocks = load_event.data.kv_stats.kv_active_blocks;
// Update worker load state
let mut states = worker_load_states.write().unwrap();
let state = states.entry(worker_id).or_insert(WorkerLoadState {
kv_active_blocks: None,
kv_total_blocks: None,
});
state.kv_active_blocks = Some(active_blocks);
drop(states);
// Recalculate all busy instances and update
let states = worker_load_states.read().unwrap();
let busy_instances: Vec<i64> = states
.iter()
.filter_map(|(&id, state)| {
state.is_busy(busy_threshold).then_some(id)
})
.collect();
drop(states);
// Only update if busy_instances has changed
if busy_instances != previous_busy_instances {
tracing::debug!("Busy instances changed: {:?}", busy_instances);
client.update_free_instances(&busy_instances);
previous_busy_instances = busy_instances;
}
}
}
}
}
tracing::info!("Worker monitoring task exiting");
});
Ok(())
}
}
......@@ -5,6 +5,8 @@ import asyncio
import json
import logging
import os
import random
from typing import Any, Dict
import aiohttp
import pytest
......@@ -22,6 +24,19 @@ SPEEDUP_RATIO = 10.0
NUM_REQUESTS = 100
PORT = 8090 # Starting port for mocker instances
# Shared test payload for all tests
TEST_PAYLOAD: Dict[str, Any] = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.",
}
],
"stream": True,
"max_tokens": 10,
}
class MockerProcess(ManagedProcess):
"""Manages a single mocker engine instance"""
......@@ -88,6 +103,89 @@ class KVRouterProcess(ManagedProcess):
super().__exit__(exc_type, exc_val, exc_tb)
async def send_request_with_retry(url: str, payload: dict, max_retries: int = 4):
"""Send a single request with exponential backoff retry"""
wait_time = 1 # Start with 1 second
for attempt in range(max_retries + 1):
await asyncio.sleep(wait_time)
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=payload) as response:
if response.status == 200:
# Read the response to ensure it's valid
async for _ in response.content:
pass
logger.info(f"First request succeeded on attempt {attempt + 1}")
return True
else:
logger.warning(
f"Attempt {attempt + 1} failed with status {response.status}"
)
except Exception as e:
logger.warning(f"Attempt {attempt + 1} failed with error: {e}")
if attempt < max_retries:
wait_time *= 2 # Double the wait time
return False
async def send_concurrent_requests(urls: list, payload: dict, num_requests: int):
"""Send multiple requests concurrently, alternating between URLs if multiple provided"""
# First, send test requests with retry to ensure all systems are ready
for i, url in enumerate(urls):
logger.info(f"Sending initial test request to URL {i} ({url}) with retry...")
if not await send_request_with_retry(url, payload):
raise RuntimeError(f"Failed to connect to URL {i} after multiple retries")
async def send_single_request(session: aiohttp.ClientSession, request_id: int):
# Alternate between URLs based on request_id
url = urls[request_id % len(urls)]
url_index = request_id % len(urls)
try:
async with session.post(url, json=payload) as response:
if response.status != 200:
logger.error(
f"Request {request_id} to URL {url_index} failed with status {response.status}"
)
return False
# For streaming responses, read the entire stream
chunks = []
async for line in response.content:
if line:
chunks.append(line)
logger.debug(
f"Request {request_id} to URL {url_index} completed with {len(chunks)} chunks"
)
return True
except Exception as e:
logger.error(
f"Request {request_id} to URL {url_index} failed with error: {e}"
)
return False
# Send all requests at once
async with aiohttp.ClientSession() as session:
tasks = [send_single_request(session, i) for i in range(num_requests)]
results = await asyncio.gather(*tasks)
successful = sum(1 for r in results if r)
failed = sum(1 for r in results if not r)
logger.info(f"Completed all requests: {successful} successful, {failed} failed")
assert (
successful == num_requests
), f"Expected {num_requests} successful requests, got {successful}"
logger.info(f"All {num_requests} requests completed successfully")
@pytest.mark.pre_merge
def test_mocker_kv_router(request, runtime_services):
"""
......@@ -128,26 +226,13 @@ def test_mocker_kv_router(request, runtime_services):
for mocker in mocker_processes:
mocker.__enter__()
# Send test requests
test_payload = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.",
}
],
"stream": True,
"max_tokens": 10,
}
# Use async to send requests concurrently for better performance
asyncio.run(
send_concurrent_requests(
[
f"http://localhost:{frontend_port}/v1/chat/completions"
], # Pass as list
test_payload,
TEST_PAYLOAD,
NUM_REQUESTS,
)
)
......@@ -209,19 +294,6 @@ def test_mocker_two_kv_router(request, runtime_services):
for mocker in mocker_processes:
mocker.__enter__()
# Send test requests
test_payload = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.",
}
],
"stream": True,
"max_tokens": 10,
}
# Build URLs for both routers
router_urls = [
f"http://localhost:{port}/v1/chat/completions" for port in router_ports
......@@ -231,7 +303,7 @@ def test_mocker_two_kv_router(request, runtime_services):
asyncio.run(
send_concurrent_requests(
router_urls,
test_payload,
TEST_PAYLOAD,
NUM_REQUESTS,
)
)
......@@ -253,84 +325,177 @@ def test_mocker_two_kv_router(request, runtime_services):
os.unlink(mocker_args_file)
async def send_request_with_retry(url: str, payload: dict, max_retries: int = 4):
"""Send a single request with exponential backoff retry"""
wait_time = 1 # Start with 1 second
@pytest.mark.pre_merge
@pytest.mark.skip(reason="Flaky, temporarily disabled")
def test_mocker_kv_router_overload_503(request, runtime_services):
"""
Test that KV router returns 503 when all workers are busy.
This test uses limited resources to intentionally trigger the overload condition.
"""
# runtime_services starts etcd and nats
logger.info("Starting mocker KV router overload test for 503 status")
# Create mocker args file with limited resources
mocker_args = {
"speedup_ratio": 10,
"block_size": 4, # Smaller block size
"num_gpu_blocks": 64, # Limited GPU blocks to exhaust quickly
}
mocker_args_file = os.path.join(request.node.name, "mocker_args_overload.json")
with open(mocker_args_file, "w") as f:
json.dump(mocker_args, f)
for attempt in range(max_retries + 1):
await asyncio.sleep(wait_time)
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=payload) as response:
if response.status == 200:
# Read the response to ensure it's valid
async for _ in response.content:
pass
logger.info(f"First request succeeded on attempt {attempt + 1}")
return True
else:
logger.warning(
f"Attempt {attempt + 1} failed with status {response.status}"
# Start KV router (frontend) with limited block size
frontend_port = PORT + 10 # Use different port to avoid conflicts
logger.info(
f"Starting KV router frontend on port {frontend_port} with limited resources"
)
except Exception as e:
logger.warning(f"Attempt {attempt + 1} failed with error: {e}")
if attempt < max_retries:
wait_time *= 2 # Double the wait time
# Custom command for router with limited block size
command = [
"python",
"-m",
"dynamo.frontend",
"--busy-threshold",
"0.2",
"--kv-cache-block-size",
"4", # Match the mocker's block size
"--router-mode",
"kv",
"--http-port",
str(frontend_port),
]
return False
kv_router = ManagedProcess(
command=command,
timeout=60,
display_output=True,
health_check_ports=[frontend_port],
health_check_urls=[
(
f"http://localhost:{frontend_port}/v1/models",
lambda r: r.status_code == 200,
)
],
log_dir=request.node.name,
terminate_existing=False,
)
kv_router.__enter__()
# Start single mocker instance with limited resources
endpoint = "dyn://test-namespace.mocker.generate"
logger.info(
f"Starting single mocker instance with limited resources on endpoint {endpoint}"
)
async def send_concurrent_requests(urls: list, payload: dict, num_requests: int):
"""Send multiple requests concurrently, alternating between URLs if multiple provided"""
mocker = MockerProcess(request, endpoint, mocker_args_file)
mocker.__enter__()
# First, send test requests with retry to ensure all systems are ready
for i, url in enumerate(urls):
logger.info(f"Sending initial test request to URL {i} ({url}) with retry...")
if not await send_request_with_retry(url, payload):
raise RuntimeError(f"Failed to connect to URL {i} after multiple retries")
url = f"http://localhost:{frontend_port}/v1/chat/completions"
async def send_single_request(session: aiohttp.ClientSession, request_id: int):
# Alternate between URLs based on request_id
url = urls[request_id % len(urls)]
url_index = request_id % len(urls)
# Custom payload for 503 test with more tokens to consume resources
test_payload_503 = {
**TEST_PAYLOAD,
"max_tokens": 50, # Longer output to consume more blocks
}
# First, send one request with retry to ensure system is ready
logger.info("Sending initial request to ensure system is ready...")
asyncio.run(send_concurrent_requests([url], test_payload_503, 1))
# Now send 50 concurrent requests to exhaust resources, then verify 503
logger.info("Sending 50 concurrent requests to exhaust resources...")
async def exhaust_resources_and_verify_503():
async with aiohttp.ClientSession() as session:
# Start 50 long-running requests concurrently
tasks = []
for i in range(50):
# Create unique shuffled content for each request
content_words = TEST_PAYLOAD["messages"][0]["content"].split()
random.shuffle(content_words)
shuffled_content = " ".join(content_words)
# Create unique payload for this request
unique_payload = {
**TEST_PAYLOAD,
"max_tokens": 50,
"messages": [
{**TEST_PAYLOAD["messages"][0], "content": shuffled_content}
],
}
async def send_long_request(req_id, payload):
try:
async with session.post(url, json=payload) as response:
if response.status != 200:
logger.error(
f"Request {request_id} to URL {url_index} failed with status {response.status}"
if response.status == 200:
# Don't read the response fully, just hold the connection
await asyncio.sleep(
10
) # Hold connection for 10 seconds
return True
else:
logger.info(
f"Request {req_id} got status {response.status}"
)
return False
except Exception as e:
logger.info(f"Request {req_id} failed: {e}")
return False
# For streaming responses, read the entire stream
chunks = []
async for line in response.content:
if line:
chunks.append(line)
logger.debug(
f"Request {request_id} to URL {url_index} completed with {len(chunks)} chunks"
tasks.append(
asyncio.create_task(send_long_request(i, unique_payload))
)
return True
except Exception as e:
# Wait briefly to ensure requests are in-flight
await asyncio.sleep(0.2)
# Now send one more request that should get 503
logger.info("Sending additional request that should receive 503...")
try:
async with session.post(url, json=test_payload_503) as response:
status_code = response.status
if status_code == 503:
body = await response.json()
logger.info(f"Got expected 503 response: {body}")
assert "Service temporarily unavailable" in body.get(
"error", ""
) or "All workers are busy" in body.get(
"error", ""
), f"Expected service overload error message, got: {body}"
return True
else:
logger.error(f"Expected 503 but got {status_code}")
if status_code == 200:
logger.error(
f"Request {request_id} to URL {url_index} failed with error: {e}"
"Request unexpectedly succeeded when it should have been rejected"
)
return False
except Exception as e:
logger.error(f"Failed to send overload test request: {e}")
return False
finally:
# Cancel all background tasks
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
# Send all requests at once
async with aiohttp.ClientSession() as session:
tasks = [send_single_request(session, i) for i in range(num_requests)]
results = await asyncio.gather(*tasks)
# Run the test
success = asyncio.run(exhaust_resources_and_verify_503())
assert success, "Failed to verify 503 response when resources are exhausted"
successful = sum(1 for r in results if r)
failed = sum(1 for r in results if not r)
logger.info("Successfully verified 503 response when all workers are busy")
logger.info(f"Completed all requests: {successful} successful, {failed} failed")
finally:
# Clean up
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
assert (
successful == num_requests
), f"Expected {num_requests} successful requests, got {successful}"
logger.info(f"All {num_requests} requests completed successfully")
if "mocker" in locals():
mocker.__exit__(None, None, None)
if os.path.exists(mocker_args_file):
os.unlink(mocker_args_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment