Commit 11a36651 authored by Alec, committed by GitHub
Browse files

[fix] KV Router Example fixes (#314)


Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent d38325c2
...@@ -67,7 +67,7 @@ class Router: ...@@ -67,7 +67,7 @@ class Router:
# catch specific router exceptions once we have dedicated types. # catch specific router exceptions once we have dedicated types.
except Exception as e: except Exception as e:
vllm_logger.info(f"{e}") vllm_logger.info(f"{e}")
worker_id = None worker_id = ""
vllm_logger.exception(f"Error during worker selection: {e}") vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}") vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
......
...@@ -82,9 +82,8 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -82,9 +82,8 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
""" """
Serve the triton-init.vllm.generate endpoint. Serve the triton-init.vllm.generate endpoint.
""" """
metrics_publisher = KvMetricsPublisher()
worker_component = runtime.namespace("triton-init").component("vllm") worker_component = runtime.namespace("triton-init").component("vllm")
await metrics_publisher.create_service(worker_component) await worker_component.create_service()
worker_endpoint = worker_component.endpoint("generate") worker_endpoint = worker_component.endpoint("generate")
...@@ -98,6 +97,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -98,6 +97,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
VLLM_KV_COMPONENT = "vllm" VLLM_KV_COMPONENT = "vllm"
os.environ["VLLM_KV_COMPONENT"] = str(VLLM_KV_COMPONENT) os.environ["VLLM_KV_COMPONENT"] = str(VLLM_KV_COMPONENT)
metrics_publisher = KvMetricsPublisher()
vllm_engine = VllmEngine(engine_args, metrics_publisher) vllm_engine = VllmEngine(engine_args, metrics_publisher)
await vllm_engine.initialize() await vllm_engine.initialize()
# Initially send dummy metrics to kick start, # Initially send dummy metrics to kick start,
...@@ -109,7 +109,10 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -109,7 +109,10 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
1024, 1024,
) )
await worker_endpoint.serve_endpoint(vllm_engine.generate) await asyncio.gather(
worker_endpoint.serve_endpoint(vllm_engine.generate),
metrics_publisher.create_endpoint(worker_component),
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -69,7 +69,7 @@ impl KvMetricsPublisher { ...@@ -69,7 +69,7 @@ impl KvMetricsPublisher {
}) })
} }
fn create_service<'p>( fn create_endpoint<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
component: Component, component: Component,
......
...@@ -17,7 +17,7 @@ use anyhow::Result; ...@@ -17,7 +17,7 @@ use anyhow::Result;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing as log; use tracing;
use triton_distributed_runtime::{component::Component, DistributedRuntime}; use triton_distributed_runtime::{component::Component, DistributedRuntime};
pub mod indexer; pub mod indexer;
...@@ -56,8 +56,8 @@ impl KvRouter { ...@@ -56,8 +56,8 @@ impl KvRouter {
let nats_client = runtime.nats_client(); let nats_client = runtime.nats_client();
let service_name = backend.service_name(); let service_name = backend.service_name();
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT); let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
log::info!("Component Service Name {}", service_name); tracing::info!("Component Service Name {}", service_name);
log::info!("KV Subject {}", kv_subject); tracing::info!("KV Subject {}", kv_subject);
Self::new(nats_client, service_name, kv_subject).await Self::new(nats_client, service_name, kv_subject).await
} }
...@@ -79,16 +79,26 @@ impl KvRouter { ...@@ -79,16 +79,26 @@ impl KvRouter {
let indexer = KvIndexer::new(cancellation_token.clone()); let indexer = KvIndexer::new(cancellation_token.clone());
let scheduler = KvScheduler::start(ep_rx).await?; let scheduler = KvScheduler::start(ep_rx).await?;
log::debug!("subscribing to kv events: {}", kv_subject); tracing::debug!("subscribing to kv events: {}", kv_subject);
let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?; let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?;
let kv_events_tx = indexer.event_sender(); let kv_events_tx = indexer.event_sender();
tokio::spawn(async move { tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await { while let Some(event) = kv_events_rx.next().await {
let event: RouterEvent = serde_json::from_slice(&event.payload).unwrap(); let event: RouterEvent = match serde_json::from_slice(&event.payload) {
log::debug!("received kv event: {:?}", event); Ok(event) => {
tracing::debug!("received kv event: {:?}", event);
event
}
Err(e) => {
tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
// Choosing warn and continue to process other events from other workers
// A bad event likely signals a problem with a worker, but potentially other workers are still healthy
continue;
}
};
if let Err(e) = kv_events_tx.send(event).await { if let Err(e) = kv_events_tx.send(event).await {
log::trace!("failed to send kv event to indexer; shutting down: {:?}", e); tracing::trace!("failed to send kv event to indexer; shutting down: {:?}", e);
} }
} }
}); });
...@@ -118,7 +128,7 @@ impl KvRouter { ...@@ -118,7 +128,7 @@ impl KvRouter {
.indexer .indexer
.find_matches_for_request(token_ids.as_slice()) .find_matches_for_request(token_ids.as_slice())
.await?; .await?;
log::debug!("KV router overlap_scores: {:?}", overlap_scores); tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?; let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
Ok(worker_id) Ok(worker_id)
} }
...@@ -133,36 +143,52 @@ async fn collect_endpoints( ...@@ -133,36 +143,52 @@ async fn collect_endpoints(
loop { loop {
tokio::select! { tokio::select! {
_ = cancel.cancelled() => { _ = cancel.cancelled() => {
log::debug!("cancellation token triggered"); tracing::debug!("cancellation token triggered");
break; break;
} }
_ = tokio::time::sleep(Duration::from_secs(1)) => { _ = tokio::time::sleep(Duration::from_secs(1)) => {
log::trace!("collecting endpoints for service: {}", service_name); tracing::trace!("collecting endpoints for service: {}", service_name);
} }
} }
let values = nats_client let values = match nats_client
.get_endpoints(&service_name, Duration::from_secs(1)) .get_endpoints(&service_name, Duration::from_secs(1))
.await .await
.unwrap(); {
Ok(v) => v,
Err(e) => {
tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
continue;
}
};
// [FIXME] Endpoint is parsed from nats stats handler which may not include 'data' field tracing::debug!("values: {:?}", values);
// if the service hasn't registered the handler. Need to be tolerant to this.
// Another option is to make sure the router is configured properly that
// it listens to the right subject (where other publisher has stats).
let services: Vec<Service> = values let services: Vec<Service> = values
.into_iter() .into_iter()
.filter(|v| !v.is_empty()) .filter(|v| !v.is_empty())
.map(|v| { .filter_map(|v| match serde_json::from_slice::<Service>(&v) {
let value: serde_json::Value = serde_json::from_slice(&v).unwrap(); Ok(service) => Some(service),
log::trace!("service value: {:?}", value); Err(e) => {
serde_json::from_slice(&v).unwrap() tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e);
None
}
}) })
.collect(); .collect();
tracing::debug!("services: {:?}", services);
let endpoints: Vec<Endpoint> = services.into_iter().flat_map(|s| s.endpoints).collect(); let endpoints: Vec<Endpoint> = services
.into_iter()
.flat_map(|s| s.endpoints)
.filter(|s| s.data.is_some())
.map(|s| Endpoint {
name: s.name,
subject: s.subject,
data: s.data.unwrap(),
})
.collect();
tracing::debug!("endpoints: {:?}", endpoints);
log::trace!( tracing::trace!(
"found {} endpoints for service: {}", "found {} endpoints for service: {}",
endpoints.len(), endpoints.len(),
service_name service_name
...@@ -172,7 +198,7 @@ async fn collect_endpoints( ...@@ -172,7 +198,7 @@ async fn collect_endpoints(
// process endpoints into // process endpoints into
if ep_tx.send(processed).await.is_err() { if ep_tx.send(processed).await.is_err() {
log::trace!("failed to send processed endpoints; shutting down"); tracing::trace!("failed to send processed endpoints; shutting down");
break; break;
} }
} }
......
...@@ -95,9 +95,6 @@ impl KvMetricsPublisher { ...@@ -95,9 +95,6 @@ impl KvMetricsPublisher {
let handler = Ingress::for_engine(handler)?; let handler = Ingress::for_engine(handler)?;
component component
.service_builder()
.create()
.await?
.endpoint("load_metrics") .endpoint("load_metrics")
.endpoint_builder() .endpoint_builder()
.stats_handler(move |_| { .stats_handler(move |_| {
......
...@@ -34,6 +34,13 @@ pub enum KvSchedulerError { ...@@ -34,6 +34,13 @@ pub enum KvSchedulerError {
SubscriberShutdown, SubscriberShutdown,
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlexibleEndpoint {
pub name: String,
pub subject: String,
pub data: Option<ForwardPassMetrics>,
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint { pub struct Endpoint {
pub name: String, pub name: String,
...@@ -62,7 +69,7 @@ pub struct Service { ...@@ -62,7 +69,7 @@ pub struct Service {
pub id: String, pub id: String,
pub version: String, pub version: String,
pub started: String, pub started: String,
pub endpoints: Vec<Endpoint>, pub endpoints: Vec<FlexibleEndpoint>,
} }
pub struct SchedulingRequest { pub struct SchedulingRequest {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment