Commit 11a36651 authored by Alec's avatar Alec Committed by GitHub
Browse files

[fix] KV Router Example fixes (#314)


Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent d38325c2
......@@ -67,7 +67,7 @@ class Router:
# catch specific router exceptions once we have dedicated types.
except Exception as e:
vllm_logger.info(f"{e}")
worker_id = None
worker_id = ""
vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
......
......@@ -82,9 +82,8 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
"""
Serve the triton-init.vllm.generate endpoint.
"""
metrics_publisher = KvMetricsPublisher()
worker_component = runtime.namespace("triton-init").component("vllm")
await metrics_publisher.create_service(worker_component)
await worker_component.create_service()
worker_endpoint = worker_component.endpoint("generate")
......@@ -98,6 +97,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
VLLM_KV_COMPONENT = "vllm"
os.environ["VLLM_KV_COMPONENT"] = str(VLLM_KV_COMPONENT)
metrics_publisher = KvMetricsPublisher()
vllm_engine = VllmEngine(engine_args, metrics_publisher)
await vllm_engine.initialize()
# Initially send dummy metrics to kick start,
......@@ -109,7 +109,10 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
1024,
)
await worker_endpoint.serve_endpoint(vllm_engine.generate)
await asyncio.gather(
worker_endpoint.serve_endpoint(vllm_engine.generate),
metrics_publisher.create_endpoint(worker_component),
)
if __name__ == "__main__":
......
......@@ -69,7 +69,7 @@ impl KvMetricsPublisher {
})
}
fn create_service<'p>(
fn create_endpoint<'p>(
&self,
py: Python<'p>,
component: Component,
......
......@@ -17,7 +17,7 @@ use anyhow::Result;
use futures::stream::StreamExt;
use std::{sync::Arc, time::Duration};
use tokio_util::sync::CancellationToken;
use tracing as log;
use tracing;
use triton_distributed_runtime::{component::Component, DistributedRuntime};
pub mod indexer;
......@@ -56,8 +56,8 @@ impl KvRouter {
let nats_client = runtime.nats_client();
let service_name = backend.service_name();
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
log::info!("Component Service Name {}", service_name);
log::info!("KV Subject {}", kv_subject);
tracing::info!("Component Service Name {}", service_name);
tracing::info!("KV Subject {}", kv_subject);
Self::new(nats_client, service_name, kv_subject).await
}
......@@ -79,16 +79,26 @@ impl KvRouter {
let indexer = KvIndexer::new(cancellation_token.clone());
let scheduler = KvScheduler::start(ep_rx).await?;
log::debug!("subscribing to kv events: {}", kv_subject);
tracing::debug!("subscribing to kv events: {}", kv_subject);
let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?;
let kv_events_tx = indexer.event_sender();
tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await {
let event: RouterEvent = serde_json::from_slice(&event.payload).unwrap();
log::debug!("received kv event: {:?}", event);
let event: RouterEvent = match serde_json::from_slice(&event.payload) {
Ok(event) => {
tracing::debug!("received kv event: {:?}", event);
event
}
Err(e) => {
tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
// Choosing warn and continue to process other events from other workers
// A bad event likely signals a problem with a worker, but potentially other workers are still healthy
continue;
}
};
if let Err(e) = kv_events_tx.send(event).await {
log::trace!("failed to send kv event to indexer; shutting down: {:?}", e);
tracing::trace!("failed to send kv event to indexer; shutting down: {:?}", e);
}
}
});
......@@ -118,7 +128,7 @@ impl KvRouter {
.indexer
.find_matches_for_request(token_ids.as_slice())
.await?;
log::debug!("KV router overlap_scores: {:?}", overlap_scores);
tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
Ok(worker_id)
}
......@@ -133,36 +143,52 @@ async fn collect_endpoints(
loop {
tokio::select! {
_ = cancel.cancelled() => {
log::debug!("cancellation token triggered");
tracing::debug!("cancellation token triggered");
break;
}
_ = tokio::time::sleep(Duration::from_secs(1)) => {
log::trace!("collecting endpoints for service: {}", service_name);
tracing::trace!("collecting endpoints for service: {}", service_name);
}
}
let values = nats_client
let values = match nats_client
.get_endpoints(&service_name, Duration::from_secs(1))
.await
.unwrap();
{
Ok(v) => v,
Err(e) => {
tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
continue;
}
};
// [FIXME] Endpoint is parsed from nats stats handler which may not include 'data' field
// if the service hasn't registered the handler. Need to be tolerant to this.
// Another option is to make sure the router is configured properly that
// it listens to the right subject (where other publisher has stats).
tracing::debug!("values: {:?}", values);
let services: Vec<Service> = values
.into_iter()
.filter(|v| !v.is_empty())
.map(|v| {
let value: serde_json::Value = serde_json::from_slice(&v).unwrap();
log::trace!("service value: {:?}", value);
serde_json::from_slice(&v).unwrap()
.filter_map(|v| match serde_json::from_slice::<Service>(&v) {
Ok(service) => Some(service),
Err(e) => {
tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e);
None
}
})
.collect();
tracing::debug!("services: {:?}", services);
let endpoints: Vec<Endpoint> = services.into_iter().flat_map(|s| s.endpoints).collect();
let endpoints: Vec<Endpoint> = services
.into_iter()
.flat_map(|s| s.endpoints)
.filter(|s| s.data.is_some())
.map(|s| Endpoint {
name: s.name,
subject: s.subject,
data: s.data.unwrap(),
})
.collect();
tracing::debug!("endpoints: {:?}", endpoints);
log::trace!(
tracing::trace!(
"found {} endpoints for service: {}",
endpoints.len(),
service_name
......@@ -172,7 +198,7 @@ async fn collect_endpoints(
// process endpoints into
if ep_tx.send(processed).await.is_err() {
log::trace!("failed to send processed endpoints; shutting down");
tracing::trace!("failed to send processed endpoints; shutting down");
break;
}
}
......
......@@ -95,9 +95,6 @@ impl KvMetricsPublisher {
let handler = Ingress::for_engine(handler)?;
component
.service_builder()
.create()
.await?
.endpoint("load_metrics")
.endpoint_builder()
.stats_handler(move |_| {
......
......@@ -34,6 +34,13 @@ pub enum KvSchedulerError {
SubscriberShutdown,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlexibleEndpoint {
pub name: String,
pub subject: String,
pub data: Option<ForwardPassMetrics>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
pub name: String,
......@@ -62,7 +69,7 @@ pub struct Service {
pub id: String,
pub version: String,
pub started: String,
pub endpoints: Vec<Endpoint>,
pub endpoints: Vec<FlexibleEndpoint>,
}
pub struct SchedulingRequest {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment