Unverified commit 66fd6f84, authored by Graham King and committed by GitHub
Browse files

refactor: Make `nats_client` optional internally (#3705)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 4116e389
......@@ -32,17 +32,18 @@ pub async fn run(
let cancel_token = distributed_runtime.primary_token().clone();
let endpoint_id: EndpointId = path.parse()?;
let component = distributed_runtime
let mut component = distributed_runtime
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?;
let endpoint = component
.service_builder()
.create()
.await?
.endpoint(&endpoint_id.name);
let (rt_fut, card): (Pin<Box<dyn Future<Output = _> + Send + 'static>>, _) = match engine_config
{
// We can only make the NATS service if we have NATS
if distributed_runtime.nats_client().is_some() {
// TODO fix in next PR, ServiceConfigBuilder is silly
component = component.service_builder().create().await?;
}
let endpoint = component.endpoint(&endpoint_id.name);
let rt_fut: Pin<Box<dyn Future<Output = _> + Send + 'static>> = match engine_config {
EngineConfig::StaticFull {
engine,
mut model,
......@@ -61,7 +62,7 @@ pub async fn run(
}
let fut_chat = endpoint.endpoint_builder().handler(ingress_chat).start();
(Box::pin(fut_chat), Some(model.card().clone()))
Box::pin(fut_chat)
}
EngineConfig::StaticCore {
engine: inner_engine,
......@@ -92,7 +93,7 @@ pub async fn run(
let fut = endpoint.endpoint_builder().handler(ingress).start();
(Box::pin(fut), Some(model.card().clone()))
Box::pin(fut)
}
EngineConfig::StaticRemote(_) => {
panic!("StaticRemote definitions are only for the frontend end node.");
......@@ -106,7 +107,7 @@ pub async fn run(
// Note: We must return rt_result to propagate the actual error back to the user.
// If we don't return the specific error, the programmer/user won't know what actually
// caused the endpoint service to fail, making debugging much more difficult.
let result = tokio::select! {
tokio::select! {
rt_result = rt_fut => {
tracing::debug!("Endpoint service completed");
match rt_result {
......@@ -124,21 +125,7 @@ pub async fn run(
tracing::debug!("Endpoint service cancelled");
Ok(())
}
};
// If we got an error, return it
result?;
// Cleanup on shutdown
if let Some(mut card) = card
&& let Err(err) = card
.delete_from_nats(distributed_runtime.nats_client())
.await
{
tracing::error!(%err, "delete_from_nats error on shutdown");
}
Ok(())
}
#[cfg(test)]
......
......@@ -25,7 +25,7 @@ use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::storage::key_value_store::{
EtcdStore, Key, KeyValueStore, KeyValueStoreManager,
};
use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats};
use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer;
......@@ -352,20 +352,6 @@ impl ModelDeploymentCard {
}
}
/// Delete this card from the key-value store and its URLs from the object store
pub async fn delete_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
let nats_addr = nats_client.addr();
let bucket_name = self.slug();
tracing::trace!(
nats_addr,
%bucket_name,
"Delete model deployment card from NATS"
);
nats_client
.object_store_delete_bucket(bucket_name.as_ref())
.await
}
/// Allow user to override the name we register this model under.
/// Corresponds to vllm's `--served-model-name`.
pub fn set_name(&mut self, name: &str) {
......
......@@ -190,12 +190,12 @@ pub mod llm_kvbm {
bytes: Vec<u8>,
) -> Result<()> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref());
self.drt()
.nats_client()
.client()
.publish(subject, bytes.into())
.await
.map_err(|e| anyhow::anyhow!("Failed to publish to NATS: {}", e))
let Some(nats_client) = self.drt().nats_client() else {
anyhow::bail!("KVBMDynamoRuntimeComponent EventPublisher requires NATS");
};
nats_client.client().publish(subject, bytes.into()).await?;
Ok(())
}
}
......
......@@ -290,7 +290,9 @@ impl Component {
pub async fn scrape_stats(&self, timeout: Duration) -> Result<ServiceSet> {
// Debug: scraping stats for component
let service_name = self.service_name();
let service_client = self.drt().service_client();
let Some(service_client) = self.drt().service_client() else {
anyhow::bail!("ServiceSet is gathered via NATS, do not call this in non-NATS setups.");
};
service_client
.collect_services(&service_name, timeout)
.await
......
......@@ -31,12 +31,11 @@ impl EventPublisher for Component {
bytes: Vec<u8>,
) -> Result<()> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref());
Ok(self
.drt()
.nats_client()
.client()
.publish(subject, bytes.into())
.await?)
let Some(nats_client) = self.drt().nats_client() else {
anyhow::bail!("KV router's EventPublisher requires NATS");
};
nats_client.client().publish(subject, bytes.into()).await?;
Ok(())
}
}
......@@ -47,7 +46,10 @@ impl EventSubscriber for Component {
event_name: impl AsRef<str> + Send + Sync,
) -> Result<async_nats::Subscriber> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref());
Ok(self.drt().nats_client().client().subscribe(subject).await?)
let Some(nats_client) = self.drt().nats_client() else {
anyhow::bail!("KV router's EventSubscriber requires NATS");
};
Ok(nats_client.client().subscribe(subject).await?)
}
async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
......
......@@ -31,12 +31,11 @@ impl EventPublisher for Namespace {
bytes: Vec<u8>,
) -> Result<()> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref());
Ok(self
.drt()
.nats_client()
.client()
.publish(subject, bytes.into())
.await?)
let Some(nats_client) = self.drt().nats_client() else {
anyhow::bail!("KV router's Namespace EventPublisher requires NATS");
};
nats_client.client().publish(subject, bytes.into()).await?;
Ok(())
}
}
......@@ -47,7 +46,10 @@ impl EventSubscriber for Namespace {
event_name: impl AsRef<str> + Send + Sync,
) -> Result<async_nats::Subscriber> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref());
Ok(self.drt().nats_client().client().subscribe(subject).await?)
let Some(nats_client) = self.drt().nats_client() else {
anyhow::bail!("KV router's Namespace EventSubscriber requires NATS");
};
Ok(nats_client.client().subscribe(subject).await?)
}
async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use async_nats::service::Service as NatsService;
use async_nats::service::ServiceExt as _;
use derive_builder::Builder;
use derive_getters::Dissolve;
use std::collections::HashMap;
use std::sync::Mutex;
use std::sync::{Arc, Mutex};
use super::*;
use crate::component::Component;
pub use super::endpoint::EndpointStats;
use educe::Educe;
type StatsHandlerRegistry = Arc<Mutex<HashMap<String, EndpointStatsHandler>>>;
pub type StatsHandler =
Box<dyn FnMut(String, EndpointStats) -> serde_json::Value + Send + Sync + 'static>;
pub type EndpointStatsHandler =
Box<dyn FnMut(EndpointStats) -> serde_json::Value + Send + Sync + 'static>;
pub const PROJECT_NAME: &str = "Dynamo";
const SERVICE_VERSION: &str = env!("CARGO_PKG_VERSION");
#[derive(Educe, Builder, Dissolve)]
#[educe(Debug)]
......@@ -29,75 +37,91 @@ pub struct ServiceConfig {
impl ServiceConfigBuilder {
/// Create the [`Component`]'s service and store it in the registry.
pub async fn create(self) -> Result<Component> {
pub async fn create(self) -> anyhow::Result<Component> {
let (component, description) = self.build_internal()?.dissolve();
let version = "0.0.1".to_string();
let service_name = component.service_name();
// Pre-check to save cost of creating the service, but don't hold the lock
if component
.drt
.component_registry
.inner
.lock()
.await
.services
.contains_key(&service_name)
{
anyhow::bail!("Service {service_name} already exists");
}
let Some(nats_client) = component.drt.nats_client() else {
anyhow::bail!("Cannot create NATS service without NATS.");
};
let (nats_service, stats_reg) =
build_nats_service(nats_client, &component, description).await?;
let mut guard = component.drt.component_registry.inner.lock().await;
if !guard.services.contains_key(&service_name) {
// Normal case
guard.services.insert(service_name.clone(), nats_service);
guard.stats_handlers.insert(service_name, stats_reg);
drop(guard);
} else {
drop(guard);
let _ = nats_service.stop().await;
return Err(anyhow::anyhow!(
"Service create race for {service_name}, now already exists"
));
}
// Register metrics callback. CRITICAL: Never fail service creation for metrics issues.
if let Err(err) = component.start_scraping_nats_service_component_metrics() {
tracing::debug!(
"Metrics registration failed for '{}': {}",
component.service_name(),
err
);
}
Ok(component)
}
}
async fn build_nats_service(
nats_client: &crate::transports::nats::Client,
component: &Component,
description: Option<String>,
) -> anyhow::Result<(NatsService, StatsHandlerRegistry)> {
let service_name = component.service_name();
log::debug!("component: {component}; creating, service_name: {service_name}");
tracing::trace!("component: {component}; creating, service_name: {service_name}");
let description = description.unwrap_or(format!(
"{PROJECT_NAME} component {} in namespace {}",
component.name, component.namespace
));
let stats_handler_registry: Arc<Mutex<HashMap<String, EndpointStatsHandler>>> =
Arc::new(Mutex::new(HashMap::new()));
let stats_handler_registry: StatsHandlerRegistry = Arc::new(Mutex::new(HashMap::new()));
let stats_handler_registry_clone = stats_handler_registry.clone();
let mut guard = component.drt.component_registry.inner.lock().await;
if guard.services.contains_key(&service_name) {
return Err(anyhow::anyhow!("Service already exists"));
}
// create service on the secondary runtime
let builder = component.drt.nats_client.client().service_builder();
let nats_service_builder = nats_client.client().service_builder();
tracing::debug!("Starting service: {}", service_name);
let service_builder = builder
let nats_service_builder =
nats_service_builder
.description(description)
.stats_handler(move |name, stats| {
log::trace!("stats_handler: {name}, {stats:?}");
tracing::trace!("stats_handler: {name}, {stats:?}");
let mut guard = stats_handler_registry.lock().unwrap();
match guard.get_mut(&name) {
Some(handler) => handler(stats),
None => serde_json::Value::Null,
}
});
tracing::debug!("Got builder");
let service = service_builder
.start(service_name.clone(), version)
let nats_service = nats_service_builder
.start(service_name, SERVICE_VERSION.to_string())
.await
.map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
// new copy of service_name as the previous one is moved into the task above
let service_name = component.service_name();
// insert the service into the registry
guard.services.insert(service_name.clone(), service);
// insert the stats handler into the registry
guard
.stats_handlers
.insert(service_name, stats_handler_registry_clone);
// drop the guard to unlock the mutex
drop(guard);
// Register metrics callback. CRITICAL: Never fail service creation for metrics issues.
if let Err(err) = component.start_scraping_nats_service_component_metrics() {
tracing::debug!(
"Metrics registration failed for '{}': {}",
component.service_name(),
err
);
}
.map_err(|e| anyhow::anyhow!("Failed to start NATS service: {e}"))?;
Ok(component)
}
Ok((nats_service, stats_handler_registry_clone))
}
impl ServiceConfigBuilder {
......
......@@ -55,7 +55,7 @@ impl DistributedRuntime {
(Some(etcd_client), store)
};
let nats_client = nats_config.clone().connect().await?;
let nats_client = Some(nats_config.clone().connect().await?);
// Start system status server for health and metrics if enabled in configuration
let config = crate::config::RuntimeConfig::from_settings().unwrap_or_default();
......@@ -96,6 +96,7 @@ impl DistributedRuntime {
system_health,
};
if let Some(nats_client_for_metrics) = nats_client_for_metrics {
let nats_client_metrics = DRTNatsClientPrometheusMetrics::new(
&distributed_runtime,
nats_client_for_metrics.client().clone(),
......@@ -112,6 +113,7 @@ impl DistributedRuntime {
});
distributed_runtime
.register_prometheus_update_callback(drt_hierarchies, nats_client_callback);
}
// Initialize the uptime gauge in SystemHealth
distributed_runtime
......@@ -245,8 +247,8 @@ impl DistributedRuntime {
)
}
pub(crate) fn service_client(&self) -> ServiceClient {
ServiceClient::new(self.nats_client.clone())
pub(crate) fn service_client(&self) -> Option<ServiceClient> {
self.nats_client().map(|nc| ServiceClient::new(nc.clone()))
}
pub async fn tcp_server(&self) -> Result<Arc<tcp::server::TcpStreamServer>> {
......@@ -261,8 +263,8 @@ impl DistributedRuntime {
.clone())
}
pub fn nats_client(&self) -> nats::Client {
self.nats_client.clone()
pub fn nats_client(&self) -> Option<&nats::Client> {
self.nats_client.as_ref()
}
/// Get system status server information if available
......
......@@ -187,7 +187,7 @@ pub struct DistributedRuntime {
// we might consider a unified transport manager here
etcd_client: Option<transports::etcd::Client>,
nats_client: transports::nats::Client,
nats_client: Option<transports::nats::Client>,
store: Arc<dyn KeyValueStore>,
tcp_server: Arc<OnceCell<Arc<transports::tcp::server::TcpStreamServer>>>,
system_status_server: Arc<OnceLock<Arc<system_status_server::SystemStatusServerInfo>>>,
......
......@@ -89,8 +89,11 @@ impl RouterMode {
}
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
let Some(nats_client) = endpoint.drt().nats_client() else {
anyhow::bail!("Missing NATS. Please ensure it is running and accessible.");
};
AddressedPushRouter::new(
endpoint.drt().nats_client.client().clone(),
nats_client.client().clone(),
endpoint.drt().tcp_server().await?,
)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment