Unverified Commit 85d83108 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: router-level request rejection (#2465)

parent 1945f599
...@@ -143,6 +143,12 @@ def parse_args(): ...@@ -143,6 +143,12 @@ def parse_args():
default=False, default=False,
help="KV Router: Enable replica synchronization across multiple router instances. When true, routers will publish and subscribe to events to maintain consistent state.", help="KV Router: Enable replica synchronization across multiple router instances. When true, routers will publish and subscribe to events to maintain consistent state.",
) )
parser.add_argument(
"--busy-threshold",
type=float,
default=None,
help="Threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. If not set, busy detection is disabled.",
)
parser.add_argument( parser.add_argument(
"--static-endpoint", "--static-endpoint",
type=validate_static_endpoint, type=validate_static_endpoint,
...@@ -205,7 +211,9 @@ async def async_main(): ...@@ -205,7 +211,9 @@ async def async_main():
kwargs = { kwargs = {
"http_port": flags.http_port, "http_port": flags.http_port,
"kv_cache_block_size": flags.kv_cache_block_size, "kv_cache_block_size": flags.kv_cache_block_size,
"router_config": RouterConfig(router_mode, kv_router_config), "router_config": RouterConfig(
router_mode, kv_router_config, flags.busy_threshold
),
} }
if flags.static_endpoint: if flags.static_endpoint:
......
...@@ -60,16 +60,22 @@ impl KvRouterConfig { ...@@ -60,16 +60,22 @@ impl KvRouterConfig {
pub struct RouterConfig { pub struct RouterConfig {
router_mode: RouterMode, router_mode: RouterMode,
kv_router_config: KvRouterConfig, kv_router_config: KvRouterConfig,
busy_threshold: Option<f64>,
} }
#[pymethods] #[pymethods]
impl RouterConfig { impl RouterConfig {
#[new] #[new]
#[pyo3(signature = (mode, config=None))] #[pyo3(signature = (mode, config=None, busy_threshold=None))]
pub fn new(mode: RouterMode, config: Option<KvRouterConfig>) -> Self { pub fn new(
mode: RouterMode,
config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
) -> Self {
Self { Self {
router_mode: mode, router_mode: mode,
kv_router_config: config.unwrap_or_default(), kv_router_config: config.unwrap_or_default(),
busy_threshold,
} }
} }
} }
...@@ -79,6 +85,7 @@ impl From<RouterConfig> for RsRouterConfig { ...@@ -79,6 +85,7 @@ impl From<RouterConfig> for RsRouterConfig {
RsRouterConfig { RsRouterConfig {
router_mode: rc.router_mode.into(), router_mode: rc.router_mode.into(),
kv_router_config: rc.kv_router_config.inner, kv_router_config: rc.kv_router_config.inner,
busy_threshold: rc.busy_threshold,
} }
} }
} }
......
...@@ -50,6 +50,7 @@ pub struct ModelWatcher { ...@@ -50,6 +50,7 @@ pub struct ModelWatcher {
notify_on_model: Notify, notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>, model_update_tx: Option<Sender<ModelUpdate>>,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
} }
const ALL_MODEL_TYPES: &[ModelType] = const ALL_MODEL_TYPES: &[ModelType] =
...@@ -61,6 +62,7 @@ impl ModelWatcher { ...@@ -61,6 +62,7 @@ impl ModelWatcher {
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_mode: RouterMode, router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
) -> ModelWatcher { ) -> ModelWatcher {
Self { Self {
manager: model_manager, manager: model_manager,
...@@ -69,6 +71,7 @@ impl ModelWatcher { ...@@ -69,6 +71,7 @@ impl ModelWatcher {
notify_on_model: Notify::new(), notify_on_model: Notify::new(),
model_update_tx: None, model_update_tx: None,
kv_router_config, kv_router_config,
busy_threshold,
} }
} }
...@@ -316,21 +319,31 @@ impl ModelWatcher { ...@@ -316,21 +319,31 @@ impl ModelWatcher {
None None
}; };
let chat_engine = let chat_engine = entrypoint::build_routed_pipeline::<
entrypoint::build_routed_pipeline::< NvCreateChatCompletionRequest,
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
NvCreateChatCompletionStreamResponse, >(
>(&card, &client, self.router_mode, kv_chooser.clone()) &card,
.await?; &client,
self.router_mode,
self.busy_threshold,
kv_chooser.clone(),
)
.await?;
self.manager self.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?; .add_chat_completions_model(&model_entry.name, chat_engine)?;
let completions_engine = let completions_engine = entrypoint::build_routed_pipeline::<
entrypoint::build_routed_pipeline::< NvCreateCompletionRequest,
NvCreateCompletionRequest, NvCreateCompletionResponse,
NvCreateCompletionResponse, >(
>(&card, &client, self.router_mode, kv_chooser) &card,
.await?; &client,
self.router_mode,
self.busy_threshold,
kv_chooser,
)
.await?;
self.manager self.manager
.add_completions_model(&model_entry.name, completions_engine)?; .add_completions_model(&model_entry.name, completions_engine)?;
} }
...@@ -338,7 +351,9 @@ impl ModelWatcher { ...@@ -338,7 +351,9 @@ impl ModelWatcher {
let push_router = PushRouter::< let push_router = PushRouter::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>, Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(client, Default::default()) >::from_client_with_threshold(
client, Default::default(), self.busy_threshold
)
.await?; .await?;
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
...@@ -348,7 +363,9 @@ impl ModelWatcher { ...@@ -348,7 +363,9 @@ impl ModelWatcher {
let push_router = PushRouter::< let push_router = PushRouter::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
Annotated<NvCreateCompletionResponse>, Annotated<NvCreateCompletionResponse>,
>::from_client(client, Default::default()) >::from_client_with_threshold(
client, Default::default(), self.busy_threshold
)
.await?; .await?;
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
...@@ -374,7 +391,9 @@ impl ModelWatcher { ...@@ -374,7 +391,9 @@ impl ModelWatcher {
let router = PushRouter::< let router = PushRouter::<
PreprocessedEmbeddingRequest, PreprocessedEmbeddingRequest,
Annotated<EmbeddingsEngineOutput>, Annotated<EmbeddingsEngineOutput>,
>::from_client(client, self.router_mode) >::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
)
.await?; .await?;
// Note: Embeddings don't need KV routing complexity // Note: Embeddings don't need KV routing complexity
......
...@@ -21,6 +21,7 @@ use crate::{ ...@@ -21,6 +21,7 @@ use crate::{
pub struct RouterConfig { pub struct RouterConfig {
pub router_mode: RouterMode, pub router_mode: RouterMode,
pub kv_router_config: KvRouterConfig, pub kv_router_config: KvRouterConfig,
pub busy_threshold: Option<f64>,
} }
impl RouterConfig { impl RouterConfig {
...@@ -28,8 +29,14 @@ impl RouterConfig { ...@@ -28,8 +29,14 @@ impl RouterConfig {
Self { Self {
router_mode, router_mode,
kv_router_config, kv_router_config,
busy_threshold: None,
} }
} }
pub fn with_busy_threshold(mut self, threshold: Option<f64>) -> Self {
self.busy_threshold = threshold;
self
}
} }
#[derive(Clone)] #[derive(Clone)]
......
...@@ -71,6 +71,7 @@ pub async fn prepare_engine( ...@@ -71,6 +71,7 @@ pub async fn prepare_engine(
model_manager.clone(), model_manager.clone(),
dynamo_runtime::pipeline::RouterMode::RoundRobin, dynamo_runtime::pipeline::RouterMode::RoundRobin,
None, None,
None,
)); ));
let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?; let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
...@@ -133,7 +134,7 @@ pub async fn prepare_engine( ...@@ -133,7 +134,7 @@ pub async fn prepare_engine(
let chat_engine = entrypoint::build_routed_pipeline::< let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>(card, &client, router_mode, kv_chooser.clone()) >(card, &client, router_mode, None, kv_chooser.clone())
.await?; .await?;
let service_name = local_model.service_name().to_string(); let service_name = local_model.service_name().to_string();
...@@ -216,6 +217,7 @@ pub async fn build_routed_pipeline<Req, Resp>( ...@@ -216,6 +217,7 @@ pub async fn build_routed_pipeline<Req, Resp>(
card: &ModelDeploymentCard, card: &ModelDeploymentCard,
client: &Client, client: &Client,
router_mode: RouterMode, router_mode: RouterMode,
busy_threshold: Option<f64>,
chooser: Option<Arc<KvRouter>>, chooser: Option<Arc<KvRouter>>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>> ) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where where
...@@ -232,11 +234,13 @@ where ...@@ -232,11 +234,13 @@ where
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator(); let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator(); let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let migration = Migration::from_mdc(card.clone()).await?.into_operator(); let migration = Migration::from_mdc(card.clone()).await?.into_operator();
let router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client( let router =
client.clone(), PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
router_mode, client.clone(),
) router_mode,
.await?; busy_threshold,
)
.await?;
let service_backend = match router_mode { let service_backend = match router_mode {
RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => { RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
ServiceBackend::from_engine(Arc::new(router)) ServiceBackend::from_engine(Arc::new(router))
......
...@@ -66,6 +66,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -66,6 +66,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
MODEL_ROOT_PATH, MODEL_ROOT_PATH,
router_config.router_mode, router_config.router_mode,
Some(router_config.kv_router_config), Some(router_config.kv_router_config),
router_config.busy_threshold,
Arc::new(http_service.clone()), Arc::new(http_service.clone()),
) )
.await?; .await?;
...@@ -109,14 +110,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -109,14 +110,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let chat_engine = entrypoint::build_routed_pipeline::< let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>(card, &client, router_mode, kv_chooser.clone()) >(card, &client, router_mode, None, kv_chooser.clone())
.await?; .await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?; manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;
let completions_engine = entrypoint::build_routed_pipeline::< let completions_engine = entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
NvCreateCompletionResponse, NvCreateCompletionResponse,
>(card, &client, router_mode, kv_chooser) >(card, &client, router_mode, None, kv_chooser)
.await?; .await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?; manager.add_completions_model(local_model.display_name(), completions_engine)?;
...@@ -188,6 +189,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -188,6 +189,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
/// Spawns a task that watches for new models in etcd at network_prefix, /// Spawns a task that watches for new models in etcd at network_prefix,
/// and registers them with the ModelManager so that the HTTP service can use them. /// and registers them with the ModelManager so that the HTTP service can use them.
#[allow(clippy::too_many_arguments)]
async fn run_watcher( async fn run_watcher(
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
...@@ -195,9 +197,16 @@ async fn run_watcher( ...@@ -195,9 +197,16 @@ async fn run_watcher(
network_prefix: &str, network_prefix: &str,
router_mode: RouterMode, router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
http_service: Arc<HttpService>, http_service: Arc<HttpService>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut watch_obj = ModelWatcher::new(runtime, model_manager, router_mode, kv_router_config); let mut watch_obj = ModelWatcher::new(
runtime,
model_manager,
router_mode,
kv_router_config,
busy_threshold,
);
tracing::info!("Watching for remote model at {network_prefix}"); tracing::info!("Watching for remote model at {network_prefix}");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?; let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
......
...@@ -108,6 +108,24 @@ impl ErrorMessage { ...@@ -108,6 +108,24 @@ impl ErrorMessage {
/// If successful, it will return the [`HttpError`] as an [`ErrorMessage::internal_server_error`] /// If successful, it will return the [`HttpError`] as an [`ErrorMessage::internal_server_error`]
/// with the details of the error. /// with the details of the error.
pub fn from_anyhow(err: anyhow::Error, alt_msg: &str) -> ErrorResponse { pub fn from_anyhow(err: anyhow::Error, alt_msg: &str) -> ErrorResponse {
// First check for PipelineError::ServiceOverloaded
if let Some(pipeline_err) =
err.downcast_ref::<dynamo_runtime::pipeline::error::PipelineError>()
{
if matches!(
pipeline_err,
dynamo_runtime::pipeline::error::PipelineError::ServiceOverloaded(_)
) {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorMessage {
error: pipeline_err.to_string(),
}),
);
}
}
// Then check for HttpError
match err.downcast::<HttpError>() { match err.downcast::<HttpError>() {
Ok(http_error) => ErrorMessage::from_http_error(http_error), Ok(http_error) => ErrorMessage::from_http_error(http_error),
Err(err) => ErrorMessage::internal_server_error(&format!("{alt_msg}: {err}")), Err(err) => ErrorMessage::internal_server_error(&format!("{alt_msg}: {err}")),
...@@ -1150,6 +1168,22 @@ mod tests { ...@@ -1150,6 +1168,22 @@ mod tests {
); );
} }
#[test]
fn test_service_overloaded_error_response_from_anyhow() {
use dynamo_runtime::pipeline::error::PipelineError;
let err: anyhow::Error = PipelineError::ServiceOverloaded(
"All workers are busy, please retry later".to_string(),
)
.into();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
assert_eq!(
response.error,
"Service temporarily unavailable: All workers are busy, please retry later"
);
}
#[test] #[test]
fn test_validate_input_is_text_only_accepts_text() { fn test_validate_input_is_text_only_accepts_text() {
let request = make_base_request(); let request = make_base_request();
......
...@@ -29,13 +29,13 @@ pub mod scoring; ...@@ -29,13 +29,13 @@ pub mod scoring;
pub mod sequence; pub mod sequence;
use crate::{ use crate::{
discovery::{ModelEntry, MODEL_ROOT_PATH},
kv_router::{ kv_router::{
approx::ApproxKvIndexer, approx::ApproxKvIndexer,
indexer::{ indexer::{
compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface, compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface,
KvRouterError, OverlapScores, RouterEvent, KvRouterError, OverlapScores, RouterEvent,
}, },
metrics_aggregator::watch_model_runtime_configs,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints, scoring::ProcessedEndpoints,
...@@ -177,14 +177,25 @@ impl KvRouter { ...@@ -177,14 +177,25 @@ impl KvRouter {
} }
}; };
// Create runtime config watcher // Create runtime config watcher using the generic etcd watcher
// TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality // TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality
let etcd_client = component let etcd_client = component
.drt() .drt()
.etcd_client() .etcd_client()
.expect("Cannot KV route without etcd client"); .expect("Cannot KV route without etcd client");
let runtime_configs_rx =
watch_model_runtime_configs(etcd_client, cancellation_token.clone()).await?; use dynamo_runtime::utils::typed_prefix_watcher::{
key_extractors, watch_prefix_with_extraction,
};
let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client,
MODEL_ROOT_PATH,
key_extractors::lease_id,
|model_entry: ModelEntry| model_entry.runtime_config,
cancellation_token.clone(),
)
.await?;
let runtime_configs_rx = runtime_configs_watcher.receiver();
let indexer = if kv_router_config.use_kv_events { let indexer = if kv_router_config.use_kv_events {
Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size)) Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
......
...@@ -18,14 +18,10 @@ use std::sync::Once; ...@@ -18,14 +18,10 @@ use std::sync::Once;
pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics}; pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics};
use crate::kv_router::KV_METRICS_ENDPOINT; use crate::kv_router::KV_METRICS_ENDPOINT;
use crate::discovery::{ModelEntry, MODEL_ROOT_PATH};
use crate::kv_router::scoring::Endpoint; use crate::kv_router::scoring::Endpoint;
use crate::kv_router::ProcessedEndpoints; use crate::kv_router::ProcessedEndpoints;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::transports::etcd::{Client as EtcdClient, WatchEvent};
use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result}; use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result};
use std::collections::HashMap;
use tokio::sync::watch; use tokio::sync::watch;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -212,71 +208,3 @@ pub async fn collect_endpoints_task( ...@@ -212,71 +208,3 @@ pub async fn collect_endpoints_task(
} }
} }
} }
pub async fn watch_model_runtime_configs(
etcd_client: EtcdClient,
cancellation_token: CancellationToken,
) -> Result<watch::Receiver<HashMap<i64, ModelRuntimeConfig>>> {
let (watch_tx, watch_rx) = watch::channel(HashMap::new());
let prefix_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
let (_prefix, _watcher, mut events_rx) = prefix_watcher.dissolve();
tokio::spawn(async move {
let mut runtime_configs: HashMap<i64, ModelRuntimeConfig> = HashMap::new();
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::debug!("Runtime config watcher cancelled");
break;
}
event = events_rx.recv() => {
let Some(event) = event else {
tracing::debug!("Runtime config watch stream closed");
break;
};
match event {
WatchEvent::Put(kv) => {
let Ok(model_entry) = serde_json::from_slice::<ModelEntry>(kv.value()) else {
tracing::warn!(
"Failed to parse ModelEntry from etcd. Key: {}",
kv.key_str().unwrap_or("<invalid>")
);
continue;
};
let lease_id = kv.lease();
if let Some(runtime_config) = model_entry.runtime_config {
runtime_configs.insert(lease_id, runtime_config);
tracing::trace!("Updated runtime config for lease_id: {}", lease_id);
} else {
runtime_configs.remove(&lease_id);
tracing::trace!("Removed runtime config (no config in ModelEntry)");
}
if watch_tx.send(runtime_configs.clone()).is_err() {
tracing::error!("Failed to send runtime configs update; receiver dropped");
break;
}
}
WatchEvent::Delete(kv) => {
let lease_id = kv.lease();
runtime_configs.remove(&lease_id);
tracing::trace!("Removed runtime config for deleted entry");
if watch_tx.send(runtime_configs.clone()).is_err() {
tracing::error!("Failed to send runtime configs update; receiver dropped");
break;
}
}
}
}
}
}
});
Ok(watch_rx)
}
...@@ -503,14 +503,16 @@ impl WorkerMetricsPublisher { ...@@ -503,14 +503,16 @@ impl WorkerMetricsPublisher {
let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone())); let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone()));
let handler = Ingress::for_engine(handler)?; let handler = Ingress::for_engine(handler)?;
// let worker_id = component let worker_id = component
// .drt() .drt()
// .primary_lease() .primary_lease()
// .map(|lease| lease.id()) .map(|lease| lease.id())
// .unwrap_or_else(|| { .unwrap_or_else(|| {
// tracing::warn!("Component is static, assuming worker_id of 0"); tracing::warn!("Component is static, assuming worker_id of 0");
// 0 0
// }); });
self.start_nats_metrics_publishing(component.namespace().clone(), worker_id);
component component
.endpoint(KV_METRICS_ENDPOINT) .endpoint(KV_METRICS_ENDPOINT)
......
...@@ -42,8 +42,8 @@ use futures::StreamExt; ...@@ -42,8 +42,8 @@ use futures::StreamExt;
use rand::Rng; use rand::Rng;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex, OnceCell}; use tokio::sync::{mpsc, Mutex, OnceCell};
use tokio::time::{interval, Duration};
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid; use uuid::Uuid;
...@@ -174,7 +174,7 @@ impl MockVllmEngine { ...@@ -174,7 +174,7 @@ impl MockVllmEngine {
(schedulers, kv_event_receivers) (schedulers, kv_event_receivers)
} }
/// Start background tasks to poll and publish metrics every second /// Start background tasks to publish metrics on change
async fn start_metrics_publishing( async fn start_metrics_publishing(
schedulers: &[Scheduler], schedulers: &[Scheduler],
component: Option<Component>, component: Option<Component>,
...@@ -202,19 +202,18 @@ impl MockVllmEngine { ...@@ -202,19 +202,18 @@ impl MockVllmEngine {
tracing::info!("Starting metrics background tasks"); tracing::info!("Starting metrics background tasks");
for (dp_rank, scheduler) in schedulers.iter().enumerate() { for (dp_rank, scheduler) in schedulers.iter().enumerate() {
let scheduler = scheduler.clone(); let mut metrics_rx = scheduler.metrics_receiver();
let publisher = metrics_publisher.clone(); let publisher = metrics_publisher.clone();
let dp_rank = dp_rank as u32; let dp_rank = dp_rank as u32;
let cancel_token = cancel_token.clone(); let cancel_token = cancel_token.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut interval = interval(Duration::from_millis(100));
loop { loop {
tokio::select! { tokio::select! {
_ = interval.tick() => { // Watch for metrics changes
// Get metrics from scheduler Ok(_) = metrics_rx.changed() => {
let metrics = scheduler.get_forward_pass_metrics().await; // Get the latest metrics
let metrics = metrics_rx.borrow().clone();
// Publish metrics // Publish metrics
if let Err(e) = publisher.publish(Arc::new(metrics)) { if let Err(e) = publisher.publish(Arc::new(metrics)) {
...@@ -568,7 +567,7 @@ mod integration_tests { ...@@ -568,7 +567,7 @@ mod integration_tests {
let engine = MockVllmEngine::new(args); let engine = MockVllmEngine::new(args);
engine.start(test_component.clone()).await?; engine.start(test_component.clone()).await?;
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; tokio::time::sleep(Duration::from_millis(500)).await;
let engine = Arc::new(engine); let engine = Arc::new(engine);
tracing::info!("✓ MockVllmEngine created with DP_SIZE: {DP_SIZE}"); tracing::info!("✓ MockVllmEngine created with DP_SIZE: {DP_SIZE}");
...@@ -598,7 +597,7 @@ mod integration_tests { ...@@ -598,7 +597,7 @@ mod integration_tests {
tracing::info!("✓ Server started in background"); tracing::info!("✓ Server started in background");
// Give server time to start // Give server time to start
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; tokio::time::sleep(Duration::from_millis(500)).await;
tracing::info!("✓ Server startup delay completed"); tracing::info!("✓ Server startup delay completed");
// Print all registered instances from etcd // Print all registered instances from etcd
...@@ -733,7 +732,7 @@ mod integration_tests { ...@@ -733,7 +732,7 @@ mod integration_tests {
cancel_token, cancel_token,
) )
.await; .await;
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; tokio::time::sleep(Duration::from_millis(500)).await;
let processed_endpoints = metrics_aggregator.get_endpoints(); let processed_endpoints = metrics_aggregator.get_endpoints();
tracing::info!( tracing::info!(
......
...@@ -250,11 +250,10 @@ impl SchedulerState { ...@@ -250,11 +250,10 @@ impl SchedulerState {
/// Manages scheduling of requests using KvManager resources /// Manages scheduling of requests using KvManager resources
#[derive(Clone)] #[derive(Clone)]
pub struct Scheduler { pub struct Scheduler {
dp_rank: Option<u32>,
state: Arc<Mutex<SchedulerState>>, state: Arc<Mutex<SchedulerState>>,
kv_manager: Arc<Mutex<KvManager>>, kv_manager: Arc<Mutex<KvManager>>,
request_tx: mpsc::UnboundedSender<DirectRequest>, request_tx: mpsc::UnboundedSender<DirectRequest>,
hit_rates: Arc<Mutex<VecDeque<f32>>>, metrics_rx: tokio::sync::watch::Receiver<ForwardPassMetrics>,
} }
impl Scheduler { impl Scheduler {
...@@ -292,13 +291,16 @@ impl Scheduler { ...@@ -292,13 +291,16 @@ impl Scheduler {
// Create channel for request handling // Create channel for request handling
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>(); let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let mut initial_metrics = ForwardPassMetrics::default();
initial_metrics.worker_stats.data_parallel_rank = dp_rank;
let (metrics_tx, metrics_rx) =
tokio::sync::watch::channel::<ForwardPassMetrics>(initial_metrics);
// Create a clone for the background task // Create a clone for the background task
let state_clone = state.clone(); let state_clone = state.clone();
let kv_manager_clone = kv_manager.clone(); let kv_manager_clone = kv_manager.clone();
let output_tx_clone = output_tx.clone(); let output_tx_clone = output_tx.clone();
let cancel_token_clone = cancellation_token.unwrap_or_default().clone(); let cancel_token_clone = cancellation_token.unwrap_or_default().clone();
let hit_rates_clone = hit_rates.clone();
// Spawn main background task with cancellation token // Spawn main background task with cancellation token
tokio::spawn(async move { tokio::spawn(async move {
...@@ -376,7 +378,7 @@ impl Scheduler { ...@@ -376,7 +378,7 @@ impl Scheduler {
// Compute and store hit rate // Compute and store hit rate
let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 }; let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 };
{ {
let mut hit_rates_guard = hit_rates_clone.lock().await; let mut hit_rates_guard = hit_rates.lock().await;
hit_rates_guard.push_back(hit_rate); hit_rates_guard.push_back(hit_rate);
if hit_rates_guard.len() > 1000 { if hit_rates_guard.len() > 1000 {
hit_rates_guard.pop_front(); hit_rates_guard.pop_front();
...@@ -442,6 +444,17 @@ impl Scheduler { ...@@ -442,6 +444,17 @@ impl Scheduler {
state_guard.reset_active_tokens(); state_guard.reset_active_tokens();
{
let hit_rates_guard = hit_rates.lock().await;
let metrics = get_fwd_pass_metrics(
&state_guard,
&kv_manager_guard,
&hit_rates_guard,
dp_rank,
);
let _ = metrics_tx.send(metrics);
}
// Process decoding // Process decoding
let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect(); let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
if !uuids.is_empty() { if !uuids.is_empty() {
...@@ -495,6 +508,17 @@ impl Scheduler { ...@@ -495,6 +508,17 @@ impl Scheduler {
} }
} }
{
let hit_rates_guard = hit_rates.lock().await;
let metrics = get_fwd_pass_metrics(
&state_guard,
&kv_manager_guard,
&hit_rates_guard,
dp_rank,
);
let _ = metrics_tx.send(metrics);
}
if send_failed || is_complete { if send_failed || is_complete {
state_guard.complete(&uuid); state_guard.complete(&uuid);
continue; continue;
...@@ -513,11 +537,10 @@ impl Scheduler { ...@@ -513,11 +537,10 @@ impl Scheduler {
}); });
Self { Self {
dp_rank,
state, state,
kv_manager, kv_manager,
request_tx, request_tx,
hit_rates, metrics_rx,
} }
} }
...@@ -555,56 +578,60 @@ impl Scheduler { ...@@ -555,56 +578,60 @@ impl Scheduler {
kv_manager.current_capacity_perc() kv_manager.current_capacity_perc()
} }
/// Returns forward pass metrics for monitoring purposes /// Get a watch receiver for forward pass metrics
pub async fn get_forward_pass_metrics(&self) -> ForwardPassMetrics { pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<ForwardPassMetrics> {
// Acquire all locks in consistent order: state -> kv_manager -> hit_rates self.metrics_rx.clone()
let state = self.state.lock().await; }
let kv_manager = self.kv_manager.lock().await; }
let hit_rates_guard = self.hit_rates.lock().await;
// Get state metrics
let request_active_slots = state.decode.len() as u64;
let num_requests_waiting = state.waiting.len() as u64;
// Get KV manager metrics /// Calculate forward pass metrics from current state
let active_blocks_count = kv_manager.active_blocks().len() as u64; fn get_fwd_pass_metrics(
let total_capacity = kv_manager.max_capacity() as u64; state: &SchedulerState,
let gpu_cache_usage_perc = if total_capacity > 0 { kv_manager: &KvManager,
active_blocks_count as f32 / total_capacity as f32 hit_rates: &VecDeque<f32>,
} else { dp_rank: Option<u32>,
0.0 ) -> ForwardPassMetrics {
}; // Get state metrics
let request_active_slots = state.decode.len() as u64;
let num_requests_waiting = state.waiting.len() as u64;
// Get KV manager metrics
let active_blocks_count = kv_manager.active_blocks().len() as u64;
let total_capacity = kv_manager.max_capacity() as u64;
let gpu_cache_usage_perc = if total_capacity > 0 {
active_blocks_count as f32 / total_capacity as f32
} else {
0.0
};
// Get hit rate metrics // Get hit rate metrics
let gpu_prefix_cache_hit_rate = if hit_rates_guard.is_empty() { let gpu_prefix_cache_hit_rate = if hit_rates.is_empty() {
0.0 0.0
} else { } else {
let sum: f32 = hit_rates_guard.iter().sum(); let sum: f32 = hit_rates.iter().sum();
sum / hit_rates_guard.len() as f32 sum / hit_rates.len() as f32
}; };
let worker_stats = WorkerStats { let worker_stats = WorkerStats {
data_parallel_rank: self.dp_rank, data_parallel_rank: dp_rank,
request_active_slots, request_active_slots,
request_total_slots: 1024, // vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128 request_total_slots: 1024, // vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128
num_requests_waiting, num_requests_waiting,
}; };
let kv_stats = KvStats { let kv_stats = KvStats {
kv_active_blocks: active_blocks_count, kv_active_blocks: active_blocks_count,
kv_total_blocks: total_capacity, kv_total_blocks: total_capacity,
gpu_cache_usage_perc, gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate, gpu_prefix_cache_hit_rate,
}; };
let spec_decode_stats = None; let spec_decode_stats = None;
ForwardPassMetrics { ForwardPassMetrics {
worker_stats, worker_stats,
kv_stats, kv_stats,
spec_decode_stats, spec_decode_stats,
}
// Guards drop naturally here in reverse order (LIFO): hit_rates_guard, kv_manager, state
} }
} }
...@@ -761,6 +788,9 @@ mod tests { ...@@ -761,6 +788,9 @@ mod tests {
let timeout = tokio::time::sleep(Duration::from_secs(2)); let timeout = tokio::time::sleep(Duration::from_secs(2));
tokio::pin!(timeout); tokio::pin!(timeout);
// Get metrics receiver
let metrics_rx = scheduler.metrics_receiver();
// Set up debug ticker interval // Set up debug ticker interval
let mut debug_interval = interval(Duration::from_millis(500)); let mut debug_interval = interval(Duration::from_millis(500));
...@@ -770,7 +800,7 @@ mod tests { ...@@ -770,7 +800,7 @@ mod tests {
// Manual debug ticker that prints forward pass metrics // Manual debug ticker that prints forward pass metrics
_ = debug_interval.tick() => { _ = debug_interval.tick() => {
let _metrics = scheduler.get_forward_pass_metrics().await; let _metrics = metrics_rx.borrow().clone();
println!("Forward Pass Metrics: {_metrics:#?}"); println!("Forward Pass Metrics: {_metrics:#?}");
} }
...@@ -862,6 +892,9 @@ mod tests { ...@@ -862,6 +892,9 @@ mod tests {
let timeout = tokio::time::sleep(Duration::from_millis(500)); let timeout = tokio::time::sleep(Duration::from_millis(500));
tokio::pin!(timeout); tokio::pin!(timeout);
// Get metrics receiver
let metrics_rx = scheduler.metrics_receiver();
// Set up debug ticker interval // Set up debug ticker interval
let mut debug_interval = interval(Duration::from_millis(500)); let mut debug_interval = interval(Duration::from_millis(500));
...@@ -871,7 +904,7 @@ mod tests { ...@@ -871,7 +904,7 @@ mod tests {
// Manual debug ticker that prints forward pass metrics // Manual debug ticker that prints forward pass metrics
_ = debug_interval.tick() => { _ = debug_interval.tick() => {
let _metrics = scheduler.get_forward_pass_metrics().await; let _metrics = metrics_rx.borrow().clone();
println!("Forward Pass Metrics: {_metrics:#?}"); println!("Forward Pass Metrics: {_metrics:#?}");
} }
...@@ -888,8 +921,11 @@ mod tests { ...@@ -888,8 +921,11 @@ mod tests {
} }
} }
// Wait a bit for final metrics update
tokio::time::sleep(Duration::from_millis(100)).await;
// Verify forward pass metrics // Verify forward pass metrics
let metrics = scheduler.get_forward_pass_metrics().await; let metrics = metrics_rx.borrow().clone();
assert_eq!( assert_eq!(
metrics.worker_stats.num_requests_waiting, 0, metrics.worker_stats.num_requests_waiting, 0,
...@@ -958,7 +994,8 @@ mod tests { ...@@ -958,7 +994,8 @@ mod tests {
tokio::time::sleep(Duration::from_secs(1)).await; tokio::time::sleep(Duration::from_secs(1)).await;
// Check forward pass metrics // Check forward pass metrics
let metrics = scheduler.get_forward_pass_metrics().await; let metrics_rx = scheduler.metrics_receiver();
let metrics = metrics_rx.borrow().clone();
assert_eq!( assert_eq!(
metrics.kv_stats.gpu_cache_usage_perc, metrics.kv_stats.gpu_cache_usage_perc,
......
...@@ -44,6 +44,8 @@ pub struct Client { ...@@ -44,6 +44,8 @@ pub struct Client {
pub instance_source: Arc<InstanceSource>, pub instance_source: Arc<InstanceSource>,
// These are the instance source ids less those reported as down from sending rpc // These are the instance source ids less those reported as down from sending rpc
instance_avail: Arc<ArcSwap<Vec<i64>>>, instance_avail: Arc<ArcSwap<Vec<i64>>>,
// These are the instance source ids less those reported as busy (above threshold)
instance_free: Arc<ArcSwap<Vec<i64>>>,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
...@@ -59,6 +61,7 @@ impl Client { ...@@ -59,6 +61,7 @@ impl Client {
endpoint, endpoint,
instance_source: Arc::new(InstanceSource::Static), instance_source: Arc::new(InstanceSource::Static),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))), instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
}) })
} }
...@@ -76,8 +79,9 @@ impl Client { ...@@ -76,8 +79,9 @@ impl Client {
let client = Client { let client = Client {
endpoint, endpoint,
instance_source, instance_source: instance_source.clone(),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))), instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
}; };
client.monitor_instance_source(); client.monitor_instance_source();
Ok(client) Ok(client)
...@@ -108,6 +112,10 @@ impl Client { ...@@ -108,6 +112,10 @@ impl Client {
self.instance_avail.load() self.instance_avail.load()
} }
pub fn instance_ids_free(&self) -> arc_swap::Guard<Arc<Vec<i64>>> {
self.instance_free.load()
}
/// Wait for at least one Instance to be available for this Endpoint /// Wait for at least one Instance to be available for this Endpoint
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> { pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
let mut instances: Vec<Instance> = vec![]; let mut instances: Vec<Instance> = vec![];
...@@ -142,6 +150,16 @@ impl Client { ...@@ -142,6 +150,16 @@ impl Client {
tracing::debug!("inhibiting instance {instance_id}"); tracing::debug!("inhibiting instance {instance_id}");
} }
/// Update the set of free instances based on busy instance IDs
pub fn update_free_instances(&self, busy_instance_ids: &[i64]) {
let all_instance_ids = self.instance_ids();
let free_ids: Vec<i64> = all_instance_ids
.into_iter()
.filter(|id| !busy_instance_ids.contains(id))
.collect();
self.instance_free.store(Arc::new(free_ids));
}
/// Monitor the ETCD instance source and update instance_avail. /// Monitor the ETCD instance source and update instance_avail.
fn monitor_instance_source(&self) { fn monitor_instance_source(&self) {
let cancel_token = self.endpoint.drt().primary_token(); let cancel_token = self.endpoint.drt().primary_token();
...@@ -160,7 +178,10 @@ impl Client { ...@@ -160,7 +178,10 @@ impl Client {
.iter() .iter()
.map(|instance| instance.id()) .map(|instance| instance.id())
.collect(); .collect();
client.instance_avail.store(Arc::new(instance_ids));
// TODO: this resets both tracked available and free instances
client.instance_avail.store(Arc::new(instance_ids.clone()));
client.instance_free.store(Arc::new(instance_ids));
tracing::debug!("instance source updated"); tracing::debug!("instance source updated");
......
...@@ -131,6 +131,10 @@ pub enum PipelineError { ...@@ -131,6 +131,10 @@ pub enum PipelineError {
#[error("NATS KV Err: {0} for bucket '{1}")] #[error("NATS KV Err: {0} for bucket '{1}")]
KeyValueError(String, String), KeyValueError(String, String),
/// All instances are busy and cannot handle new requests
#[error("Service temporarily unavailable: {0}")]
ServiceOverloaded(String),
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
......
...@@ -2,11 +2,13 @@ ...@@ -2,11 +2,13 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::{AsyncEngineContextProvider, ResponseStream}; use super::{AsyncEngineContextProvider, ResponseStream};
use crate::utils::worker_monitor::WorkerMonitor;
use crate::{ use crate::{
component::{Client, Endpoint, InstanceSource}, component::{Client, Endpoint, InstanceSource},
engine::{AsyncEngine, Data}, engine::{AsyncEngine, Data},
pipeline::{ pipeline::{
error::PipelineErrorExt, AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn, error::{PipelineError, PipelineErrorExt},
AddressedPushRouter, AddressedRequest, Error, ManyOut, SingleIn,
}, },
protocols::maybe_error::MaybeError, protocols::maybe_error::MaybeError,
traits::DistributedRuntimeProvider, traits::DistributedRuntimeProvider,
...@@ -52,6 +54,13 @@ where ...@@ -52,6 +54,13 @@ where
/// addresses it, then passes it to AddressedPushRouter which does the network traffic. /// addresses it, then passes it to AddressedPushRouter which does the network traffic.
addressed: Arc<AddressedPushRouter>, addressed: Arc<AddressedPushRouter>,
/// Worker monitor for tracking KV cache usage
worker_monitor: Option<Arc<WorkerMonitor>>,
/// Threshold for determining when a worker is busy (0.0 to 1.0)
/// If None, busy detection is disabled
busy_threshold: Option<f64>,
/// An internal Rust type. This says that PushRouter is generic over the T and U types, /// An internal Rust type. This says that PushRouter is generic over the T and U types,
/// which are the input and output types of it's `generate` function. It allows the /// which are the input and output types of it's `generate` function. It allows the
/// compiler to specialize us at compile time. /// compiler to specialize us at compile time.
...@@ -86,15 +95,43 @@ where ...@@ -86,15 +95,43 @@ where
T: Data + Serialize, T: Data + Serialize,
U: Data + for<'de> Deserialize<'de> + MaybeError, U: Data + for<'de> Deserialize<'de> + MaybeError,
{ {
/// Create a new PushRouter without busy threshold (no busy detection)
pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> { pub async fn from_client(client: Client, router_mode: RouterMode) -> anyhow::Result<Self> {
Self::from_client_with_threshold(client, router_mode, None).await
}
/// Create a new PushRouter with optional busy threshold
pub async fn from_client_with_threshold(
client: Client,
router_mode: RouterMode,
busy_threshold: Option<f64>,
) -> anyhow::Result<Self> {
let addressed = addressed_router(&client.endpoint).await?; let addressed = addressed_router(&client.endpoint).await?;
Ok(PushRouter {
client, // Create worker monitor only if we have a threshold and are in dynamic mode
let worker_monitor = match (busy_threshold, client.instance_source.as_ref()) {
(Some(threshold), InstanceSource::Dynamic(_)) => {
let monitor = Arc::new(WorkerMonitor::new_with_threshold(
Arc::new(client.clone()),
threshold,
));
monitor.start_monitoring().await?;
Some(monitor)
}
_ => None,
};
let router = PushRouter {
client: client.clone(),
addressed, addressed,
router_mode, router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)), round_robin_counter: Arc::new(AtomicU64::new(0)),
worker_monitor,
busy_threshold,
_phantom: PhantomData, _phantom: PhantomData,
}) };
Ok(router)
} }
/// Issue a request to the next available instance in a round-robin fashion /// Issue a request to the next available instance in a round-robin fashion
...@@ -170,6 +207,21 @@ where ...@@ -170,6 +207,21 @@ where
instance_id: i64, instance_id: i64,
request: SingleIn<T>, request: SingleIn<T>,
) -> anyhow::Result<ManyOut<U>> { ) -> anyhow::Result<ManyOut<U>> {
// Check if all workers are busy (only if busy threshold is set)
if self.busy_threshold.is_some() {
let free_instances = self.client.instance_ids_free();
if free_instances.is_empty() {
// Check if we actually have any instances at all
let all_instances = self.client.instance_ids();
if !all_instances.is_empty() {
return Err(PipelineError::ServiceOverloaded(
"All workers are busy, please retry later".to_string(),
)
.into());
}
}
}
let subject = self.client.endpoint.subject_to(instance_id); let subject = self.client.endpoint.subject_to(instance_id);
let request = request.map(|req| AddressedRequest::new(req, subject)); let request = request.map(|req| AddressedRequest::new(req, subject));
......
...@@ -19,3 +19,5 @@ pub mod leader_worker_barrier; ...@@ -19,3 +19,5 @@ pub mod leader_worker_barrier;
pub mod pool; pub mod pool;
pub mod stream; pub mod stream;
pub mod task; pub mod task;
pub mod typed_prefix_watcher;
pub mod worker_monitor;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Generic etcd watcher utilities for maintaining collated state from etcd prefixes.
//!
//! This module provides reusable patterns for watching etcd prefixes and maintaining
//! HashMap-based state that automatically updates based on etcd events.
use crate::transports::etcd::{Client as EtcdClient, WatchEvent};
use crate::Result;
use etcd_client::KeyValue;
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::fmt::Debug;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
/// A generic etcd prefix watcher that maintains a HashMap of deserialized values.
///
/// This struct watches an etcd prefix and maintains a HashMap where:
/// - Keys are extracted from the etcd KeyValue (e.g., lease_id, key string, etc.)
/// - Values are extracted from the deserialized type using a value extractor
///
/// # Type Parameters
/// - `K`: The key type for the HashMap (must be hashable)
/// - `V`: The value type stored in the HashMap
pub struct TypedPrefixWatcher<K, V>
where
K: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
{
rx: watch::Receiver<HashMap<K, V>>,
}
impl<K, V> TypedPrefixWatcher<K, V>
where
K: Clone + Eq + std::hash::Hash + Send + Sync + Debug + 'static,
V: Clone + Send + Sync + 'static,
{
/// Get a receiver for the current state
pub fn receiver(&self) -> watch::Receiver<HashMap<K, V>> {
self.rx.clone()
}
/// Get the current state
pub fn current(&self) -> HashMap<K, V> {
self.rx.borrow().clone()
}
}
/// Watch an etcd prefix and maintain a HashMap of values with field extraction
///
/// This function watches an etcd prefix and maintains a HashMap where values are
/// extracted from a deserialized type using a value extractor function.
///
/// # Type Parameters
/// - `K`: The key type for the HashMap
/// - `V`: The value type stored in the HashMap
/// - `T`: The type to deserialize from etcd
///
/// # Arguments
/// - `client`: The etcd client to use
/// - `prefix`: The prefix to watch in etcd
/// - `key_extractor`: Function to extract the key from a KeyValue
/// - `value_extractor`: Function to extract the value from the deserialized type
/// - `cancellation_token`: Token to stop the watcher
///
/// # Example
/// ```ignore
/// // Watch for ModelEntry objects and extract runtime_config field
/// let watcher = watch_prefix_with_extraction(
/// etcd_client,
/// "models/",
/// |kv| Some(kv.lease()), // Use lease_id as key
/// |entry: ModelEntry| entry.runtime_config, // Extract runtime_config field
/// cancellation_token,
/// ).await?;
/// ```
pub async fn watch_prefix_with_extraction<K, V, T>(
client: EtcdClient,
prefix: impl Into<String>,
key_extractor: impl Fn(&KeyValue) -> Option<K> + Send + 'static,
value_extractor: impl Fn(T) -> Option<V> + Send + 'static,
cancellation_token: CancellationToken,
) -> Result<TypedPrefixWatcher<K, V>>
where
K: Clone + Eq + std::hash::Hash + Send + Sync + Debug + 'static,
V: Clone + Send + Sync + 'static,
T: DeserializeOwned + Send + 'static,
{
let (watch_tx, watch_rx) = watch::channel(HashMap::new());
let prefix = prefix.into();
let prefix_watcher = client.kv_get_and_watch_prefix(&prefix).await?;
let (prefix_str, _watcher, mut events_rx) = prefix_watcher.dissolve();
tokio::spawn(async move {
let mut state: HashMap<K, V> = HashMap::new();
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::debug!("TypedPrefixWatcher for prefix '{}' cancelled", prefix_str);
break;
}
event = events_rx.recv() => {
let Some(event) = event else {
tracing::debug!("TypedPrefixWatcher watch stream closed for prefix '{}'", prefix_str);
break;
};
match event {
WatchEvent::Put(kv) => {
// Extract the key
let Some(key) = key_extractor(&kv) else {
tracing::trace!("Skipping entry - key extractor returned None");
continue;
};
// Deserialize the value
let deserialized = match serde_json::from_slice::<T>(kv.value()) {
Ok(val) => val,
Err(e) => {
tracing::warn!(
"Failed to deserialize value from etcd. Key: {}, Error: {}",
kv.key_str().unwrap_or("<invalid>"),
e
);
continue;
}
};
// Extract the value
match value_extractor(deserialized) {
Some(v) => {
state.insert(key.clone(), v);
tracing::trace!("Updated entry for key {:?}", key);
}
None => {
state.remove(&key);
tracing::trace!("Removed entry for key {:?} (extractor returned None)", key);
}
}
if watch_tx.send(state.clone()).is_err() {
tracing::error!("Failed to send update; receiver dropped");
break;
}
}
WatchEvent::Delete(kv) => {
if let Some(key) = key_extractor(&kv) {
state.remove(&key);
tracing::trace!("Removed entry for deleted key {:?}", key);
if watch_tx.send(state.clone()).is_err() {
tracing::error!("Failed to send update; receiver dropped");
break;
}
}
}
}
}
}
}
tracing::info!("TypedPrefixWatcher for prefix '{}' stopped", prefix_str);
});
Ok(TypedPrefixWatcher { rx: watch_rx })
}
/// Watch an etcd prefix and maintain a HashMap of values without field extraction
///
/// This is a simpler version when you want to store the entire deserialized value.
///
/// # Example
/// ```ignore
/// // Watch for TestConfig objects directly
/// let watcher = watch_prefix(
/// etcd_client,
/// "configs/",
/// |kv| Some(kv.lease()), // Use lease_id as key
/// cancellation_token,
/// ).await?;
/// ```
pub async fn watch_prefix<K, V>(
client: EtcdClient,
prefix: impl Into<String>,
key_extractor: impl Fn(&KeyValue) -> Option<K> + Send + 'static,
cancellation_token: CancellationToken,
) -> Result<TypedPrefixWatcher<K, V>>
where
K: Clone + Eq + std::hash::Hash + Send + Sync + Debug + 'static,
V: Clone + DeserializeOwned + Send + Sync + 'static,
{
watch_prefix_with_extraction(
client,
prefix,
key_extractor,
|v: V| Some(v), // Identity function - just return the value
cancellation_token,
)
.await
}
/// Common key extractors for convenience
pub mod key_extractors {
use etcd_client::KeyValue;
/// Extract the lease ID as the key
pub fn lease_id(kv: &KeyValue) -> Option<i64> {
Some(kv.lease())
}
/// Extract the key as a string (without prefix)
pub fn key_string(prefix: &str) -> impl Fn(&KeyValue) -> Option<String> {
let prefix = prefix.to_string();
move |kv: &KeyValue| {
kv.key_str()
.ok()
.map(|k| k.strip_prefix(&prefix).unwrap_or(k).to_string())
}
}
/// Extract the full key as a string
pub fn full_key_string(kv: &KeyValue) -> Option<String> {
kv.key_str().ok().map(|s| s.to_string())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// TODO: Make load comparisons and runtime metrics a generic trait so this monitoring
// system is not tied to KV cache concepts, which are LLM-specific. This would allow
// different types of workers to define their own load metrics and busy thresholds.
use crate::component::{Client, InstanceSource};
use crate::traits::events::EventSubscriber;
use crate::traits::DistributedRuntimeProvider;
use crate::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::watch;
use tokio_stream::StreamExt;
// Constants for monitoring configuration
const KV_METRICS_SUBJECT: &str = "kv_metrics";
const MODEL_ROOT_PATH: &str = "models";
// Internal structs for deserializing metrics events
#[derive(serde::Deserialize)]
struct LoadEvent {
worker_id: i64,
data: ForwardPassMetrics,
}
#[derive(serde::Deserialize)]
struct ForwardPassMetrics {
kv_stats: KvStats,
}
#[derive(serde::Deserialize)]
struct KvStats {
kv_active_blocks: u64,
}
#[derive(serde::Deserialize)]
struct ModelEntry {
runtime_config: Option<RuntimeConfig>,
}
#[derive(serde::Deserialize)]
struct RuntimeConfig {
total_kv_blocks: Option<u64>,
}
/// Worker load monitoring state
#[derive(Clone, Debug)]
pub struct WorkerLoadState {
pub kv_active_blocks: Option<u64>,
pub kv_total_blocks: Option<u64>,
}
impl WorkerLoadState {
pub fn is_busy(&self, threshold: f64) -> bool {
match (self.kv_active_blocks, self.kv_total_blocks) {
(Some(active), Some(total)) if total > 0 => {
(active as f64) > (threshold * total as f64)
}
_ => false,
}
}
}
/// Worker monitor for tracking KV cache usage and busy states
pub struct WorkerMonitor {
client: Arc<Client>,
worker_load_states: Arc<RwLock<HashMap<i64, WorkerLoadState>>>,
busy_threshold: f64,
}
impl WorkerMonitor {
/// Create a new worker monitor with custom threshold
pub fn new_with_threshold(client: Arc<Client>, busy_threshold: f64) -> Self {
Self {
client,
worker_load_states: Arc::new(RwLock::new(HashMap::new())),
busy_threshold,
}
}
/// Get the worker load states for external access
pub fn load_states(&self) -> Arc<RwLock<HashMap<i64, WorkerLoadState>>> {
self.worker_load_states.clone()
}
/// Start background monitoring of worker KV cache usage
pub async fn start_monitoring(&self) -> anyhow::Result<()> {
let endpoint = &self.client.endpoint;
let component = endpoint.component();
let Some(etcd_client) = component.drt().etcd_client() else {
// Static mode, no monitoring needed
return Ok(());
};
let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client,
MODEL_ROOT_PATH,
key_extractors::lease_id,
|entry: ModelEntry| entry.runtime_config.and_then(|rc| rc.total_kv_blocks),
component.drt().child_token(),
)
.await?;
let mut config_events_rx = runtime_configs_watcher.receiver();
// Subscribe to KV metrics events
let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;
let worker_load_states = self.worker_load_states.clone();
let client = self.client.clone();
let cancellation_token = component.drt().child_token();
let busy_threshold = self.busy_threshold; // Capture threshold for the closure
// Spawn background monitoring task
tokio::spawn(async move {
let mut previous_busy_instances = Vec::new(); // Track previous state
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::debug!("Worker monitoring cancelled");
break;
}
// Handle runtime config updates - now receives full HashMap
_ = config_events_rx.changed() => {
let runtime_configs = config_events_rx.borrow().clone();
let mut states = worker_load_states.write().unwrap();
states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
// Update worker load states with total blocks
for (lease_id, total_blocks) in runtime_configs.iter() {
let state = states.entry(*lease_id).or_insert(WorkerLoadState {
kv_active_blocks: None,
kv_total_blocks: None,
});
state.kv_total_blocks = Some(*total_blocks);
}
}
// Handle KV metrics updates
kv_event = kv_metrics_rx.next() => {
let Some(event) = kv_event else {
tracing::debug!("KV metrics stream closed");
break;
};
if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) {
let worker_id = load_event.worker_id;
let active_blocks = load_event.data.kv_stats.kv_active_blocks;
// Update worker load state
let mut states = worker_load_states.write().unwrap();
let state = states.entry(worker_id).or_insert(WorkerLoadState {
kv_active_blocks: None,
kv_total_blocks: None,
});
state.kv_active_blocks = Some(active_blocks);
drop(states);
// Recalculate all busy instances and update
let states = worker_load_states.read().unwrap();
let busy_instances: Vec<i64> = states
.iter()
.filter_map(|(&id, state)| {
state.is_busy(busy_threshold).then_some(id)
})
.collect();
drop(states);
// Only update if busy_instances has changed
if busy_instances != previous_busy_instances {
tracing::debug!("Busy instances changed: {:?}", busy_instances);
client.update_free_instances(&busy_instances);
previous_busy_instances = busy_instances;
}
}
}
}
}
tracing::info!("Worker monitoring task exiting");
});
Ok(())
}
}
...@@ -5,6 +5,8 @@ import asyncio ...@@ -5,6 +5,8 @@ import asyncio
import json import json
import logging import logging
import os import os
import random
from typing import Any, Dict
import aiohttp import aiohttp
import pytest import pytest
...@@ -22,6 +24,19 @@ SPEEDUP_RATIO = 10.0 ...@@ -22,6 +24,19 @@ SPEEDUP_RATIO = 10.0
NUM_REQUESTS = 100 NUM_REQUESTS = 100
PORT = 8090 # Starting port for mocker instances PORT = 8090 # Starting port for mocker instances
# Shared test payload for all tests
TEST_PAYLOAD: Dict[str, Any] = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.",
}
],
"stream": True,
"max_tokens": 10,
}
class MockerProcess(ManagedProcess): class MockerProcess(ManagedProcess):
"""Manages a single mocker engine instance""" """Manages a single mocker engine instance"""
...@@ -88,6 +103,89 @@ class KVRouterProcess(ManagedProcess): ...@@ -88,6 +103,89 @@ class KVRouterProcess(ManagedProcess):
super().__exit__(exc_type, exc_val, exc_tb) super().__exit__(exc_type, exc_val, exc_tb)
async def send_request_with_retry(url: str, payload: dict, max_retries: int = 4):
"""Send a single request with exponential backoff retry"""
wait_time = 1 # Start with 1 second
for attempt in range(max_retries + 1):
await asyncio.sleep(wait_time)
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=payload) as response:
if response.status == 200:
# Read the response to ensure it's valid
async for _ in response.content:
pass
logger.info(f"First request succeeded on attempt {attempt + 1}")
return True
else:
logger.warning(
f"Attempt {attempt + 1} failed with status {response.status}"
)
except Exception as e:
logger.warning(f"Attempt {attempt + 1} failed with error: {e}")
if attempt < max_retries:
wait_time *= 2 # Double the wait time
return False
async def send_concurrent_requests(urls: list, payload: dict, num_requests: int):
"""Send multiple requests concurrently, alternating between URLs if multiple provided"""
# First, send test requests with retry to ensure all systems are ready
for i, url in enumerate(urls):
logger.info(f"Sending initial test request to URL {i} ({url}) with retry...")
if not await send_request_with_retry(url, payload):
raise RuntimeError(f"Failed to connect to URL {i} after multiple retries")
async def send_single_request(session: aiohttp.ClientSession, request_id: int):
# Alternate between URLs based on request_id
url = urls[request_id % len(urls)]
url_index = request_id % len(urls)
try:
async with session.post(url, json=payload) as response:
if response.status != 200:
logger.error(
f"Request {request_id} to URL {url_index} failed with status {response.status}"
)
return False
# For streaming responses, read the entire stream
chunks = []
async for line in response.content:
if line:
chunks.append(line)
logger.debug(
f"Request {request_id} to URL {url_index} completed with {len(chunks)} chunks"
)
return True
except Exception as e:
logger.error(
f"Request {request_id} to URL {url_index} failed with error: {e}"
)
return False
# Send all requests at once
async with aiohttp.ClientSession() as session:
tasks = [send_single_request(session, i) for i in range(num_requests)]
results = await asyncio.gather(*tasks)
successful = sum(1 for r in results if r)
failed = sum(1 for r in results if not r)
logger.info(f"Completed all requests: {successful} successful, {failed} failed")
assert (
successful == num_requests
), f"Expected {num_requests} successful requests, got {successful}"
logger.info(f"All {num_requests} requests completed successfully")
@pytest.mark.pre_merge @pytest.mark.pre_merge
def test_mocker_kv_router(request, runtime_services): def test_mocker_kv_router(request, runtime_services):
""" """
...@@ -128,26 +226,13 @@ def test_mocker_kv_router(request, runtime_services): ...@@ -128,26 +226,13 @@ def test_mocker_kv_router(request, runtime_services):
for mocker in mocker_processes: for mocker in mocker_processes:
mocker.__enter__() mocker.__enter__()
# Send test requests
test_payload = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.",
}
],
"stream": True,
"max_tokens": 10,
}
# Use async to send requests concurrently for better performance # Use async to send requests concurrently for better performance
asyncio.run( asyncio.run(
send_concurrent_requests( send_concurrent_requests(
[ [
f"http://localhost:{frontend_port}/v1/chat/completions" f"http://localhost:{frontend_port}/v1/chat/completions"
], # Pass as list ], # Pass as list
test_payload, TEST_PAYLOAD,
NUM_REQUESTS, NUM_REQUESTS,
) )
) )
...@@ -209,19 +294,6 @@ def test_mocker_two_kv_router(request, runtime_services): ...@@ -209,19 +294,6 @@ def test_mocker_two_kv_router(request, runtime_services):
for mocker in mocker_processes: for mocker in mocker_processes:
mocker.__enter__() mocker.__enter__()
# Send test requests
test_payload = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks—an endless search for the tastiest patch of greens and the softest spot to nap.",
}
],
"stream": True,
"max_tokens": 10,
}
# Build URLs for both routers # Build URLs for both routers
router_urls = [ router_urls = [
f"http://localhost:{port}/v1/chat/completions" for port in router_ports f"http://localhost:{port}/v1/chat/completions" for port in router_ports
...@@ -231,7 +303,7 @@ def test_mocker_two_kv_router(request, runtime_services): ...@@ -231,7 +303,7 @@ def test_mocker_two_kv_router(request, runtime_services):
asyncio.run( asyncio.run(
send_concurrent_requests( send_concurrent_requests(
router_urls, router_urls,
test_payload, TEST_PAYLOAD,
NUM_REQUESTS, NUM_REQUESTS,
) )
) )
...@@ -253,84 +325,177 @@ def test_mocker_two_kv_router(request, runtime_services): ...@@ -253,84 +325,177 @@ def test_mocker_two_kv_router(request, runtime_services):
os.unlink(mocker_args_file) os.unlink(mocker_args_file)
async def send_request_with_retry(url: str, payload: dict, max_retries: int = 4): @pytest.mark.pre_merge
"""Send a single request with exponential backoff retry""" @pytest.mark.skip(reason="Flaky, temporarily disabled")
wait_time = 1 # Start with 1 second def test_mocker_kv_router_overload_503(request, runtime_services):
"""
Test that KV router returns 503 when all workers are busy.
This test uses limited resources to intentionally trigger the overload condition.
"""
for attempt in range(max_retries + 1): # runtime_services starts etcd and nats
await asyncio.sleep(wait_time) logger.info("Starting mocker KV router overload test for 503 status")
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=payload) as response:
if response.status == 200:
# Read the response to ensure it's valid
async for _ in response.content:
pass
logger.info(f"First request succeeded on attempt {attempt + 1}")
return True
else:
logger.warning(
f"Attempt {attempt + 1} failed with status {response.status}"
)
except Exception as e:
logger.warning(f"Attempt {attempt + 1} failed with error: {e}")
if attempt < max_retries: # Create mocker args file with limited resources
wait_time *= 2 # Double the wait time mocker_args = {
"speedup_ratio": 10,
"block_size": 4, # Smaller block size
"num_gpu_blocks": 64, # Limited GPU blocks to exhaust quickly
}
return False mocker_args_file = os.path.join(request.node.name, "mocker_args_overload.json")
with open(mocker_args_file, "w") as f:
json.dump(mocker_args, f)
try:
# Start KV router (frontend) with limited block size
frontend_port = PORT + 10 # Use different port to avoid conflicts
logger.info(
f"Starting KV router frontend on port {frontend_port} with limited resources"
)
async def send_concurrent_requests(urls: list, payload: dict, num_requests: int): # Custom command for router with limited block size
"""Send multiple requests concurrently, alternating between URLs if multiple provided""" command = [
"python",
"-m",
"dynamo.frontend",
"--busy-threshold",
"0.2",
"--kv-cache-block-size",
"4", # Match the mocker's block size
"--router-mode",
"kv",
"--http-port",
str(frontend_port),
]
# First, send test requests with retry to ensure all systems are ready kv_router = ManagedProcess(
for i, url in enumerate(urls): command=command,
logger.info(f"Sending initial test request to URL {i} ({url}) with retry...") timeout=60,
if not await send_request_with_retry(url, payload): display_output=True,
raise RuntimeError(f"Failed to connect to URL {i} after multiple retries") health_check_ports=[frontend_port],
health_check_urls=[
(
f"http://localhost:{frontend_port}/v1/models",
lambda r: r.status_code == 200,
)
],
log_dir=request.node.name,
terminate_existing=False,
)
kv_router.__enter__()
async def send_single_request(session: aiohttp.ClientSession, request_id: int): # Start single mocker instance with limited resources
# Alternate between URLs based on request_id endpoint = "dyn://test-namespace.mocker.generate"
url = urls[request_id % len(urls)] logger.info(
url_index = request_id % len(urls) f"Starting single mocker instance with limited resources on endpoint {endpoint}"
)
try: mocker = MockerProcess(request, endpoint, mocker_args_file)
async with session.post(url, json=payload) as response: mocker.__enter__()
if response.status != 200:
logger.error(
f"Request {request_id} to URL {url_index} failed with status {response.status}"
)
return False
# For streaming responses, read the entire stream url = f"http://localhost:{frontend_port}/v1/chat/completions"
chunks = []
async for line in response.content:
if line:
chunks.append(line)
logger.debug( # Custom payload for 503 test with more tokens to consume resources
f"Request {request_id} to URL {url_index} completed with {len(chunks)} chunks" test_payload_503 = {
) **TEST_PAYLOAD,
return True "max_tokens": 50, # Longer output to consume more blocks
}
except Exception as e: # First, send one request with retry to ensure system is ready
logger.error( logger.info("Sending initial request to ensure system is ready...")
f"Request {request_id} to URL {url_index} failed with error: {e}" asyncio.run(send_concurrent_requests([url], test_payload_503, 1))
)
return False
# Send all requests at once # Now send 50 concurrent requests to exhaust resources, then verify 503
async with aiohttp.ClientSession() as session: logger.info("Sending 50 concurrent requests to exhaust resources...")
tasks = [send_single_request(session, i) for i in range(num_requests)]
results = await asyncio.gather(*tasks)
successful = sum(1 for r in results if r) async def exhaust_resources_and_verify_503():
failed = sum(1 for r in results if not r) async with aiohttp.ClientSession() as session:
# Start 50 long-running requests concurrently
tasks = []
for i in range(50):
# Create unique shuffled content for each request
content_words = TEST_PAYLOAD["messages"][0]["content"].split()
random.shuffle(content_words)
shuffled_content = " ".join(content_words)
# Create unique payload for this request
unique_payload = {
**TEST_PAYLOAD,
"max_tokens": 50,
"messages": [
{**TEST_PAYLOAD["messages"][0], "content": shuffled_content}
],
}
async def send_long_request(req_id, payload):
try:
async with session.post(url, json=payload) as response:
if response.status == 200:
# Don't read the response fully, just hold the connection
await asyncio.sleep(
10
) # Hold connection for 10 seconds
return True
else:
logger.info(
f"Request {req_id} got status {response.status}"
)
return False
except Exception as e:
logger.info(f"Request {req_id} failed: {e}")
return False
tasks.append(
asyncio.create_task(send_long_request(i, unique_payload))
)
logger.info(f"Completed all requests: {successful} successful, {failed} failed") # Wait briefly to ensure requests are in-flight
await asyncio.sleep(0.2)
# Now send one more request that should get 503
logger.info("Sending additional request that should receive 503...")
try:
async with session.post(url, json=test_payload_503) as response:
status_code = response.status
if status_code == 503:
body = await response.json()
logger.info(f"Got expected 503 response: {body}")
assert "Service temporarily unavailable" in body.get(
"error", ""
) or "All workers are busy" in body.get(
"error", ""
), f"Expected service overload error message, got: {body}"
return True
else:
logger.error(f"Expected 503 but got {status_code}")
if status_code == 200:
logger.error(
"Request unexpectedly succeeded when it should have been rejected"
)
return False
except Exception as e:
logger.error(f"Failed to send overload test request: {e}")
return False
finally:
# Cancel all background tasks
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
assert ( # Run the test
successful == num_requests success = asyncio.run(exhaust_resources_and_verify_503())
), f"Expected {num_requests} successful requests, got {successful}" assert success, "Failed to verify 503 response when resources are exhausted"
logger.info(f"All {num_requests} requests completed successfully")
logger.info("Successfully verified 503 response when all workers are busy")
finally:
# Clean up
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
if "mocker" in locals():
mocker.__exit__(None, None, None)
if os.path.exists(mocker_args_file):
os.unlink(mocker_args_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment