Unverified Commit 66fd6f84 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

refactor: Make `nats_client` optional internally (#3705)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 4116e389
...@@ -32,17 +32,18 @@ pub async fn run( ...@@ -32,17 +32,18 @@ pub async fn run(
let cancel_token = distributed_runtime.primary_token().clone(); let cancel_token = distributed_runtime.primary_token().clone();
let endpoint_id: EndpointId = path.parse()?; let endpoint_id: EndpointId = path.parse()?;
let component = distributed_runtime let mut component = distributed_runtime
.namespace(&endpoint_id.namespace)? .namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?; .component(&endpoint_id.component)?;
let endpoint = component
.service_builder() // We can only make the NATS service if we have NATS
.create() if distributed_runtime.nats_client().is_some() {
.await? // TODO fix in next PR, ServiceConfigBuilder is silly
.endpoint(&endpoint_id.name); component = component.service_builder().create().await?;
}
let (rt_fut, card): (Pin<Box<dyn Future<Output = _> + Send + 'static>>, _) = match engine_config let endpoint = component.endpoint(&endpoint_id.name);
{
let rt_fut: Pin<Box<dyn Future<Output = _> + Send + 'static>> = match engine_config {
EngineConfig::StaticFull { EngineConfig::StaticFull {
engine, engine,
mut model, mut model,
...@@ -61,7 +62,7 @@ pub async fn run( ...@@ -61,7 +62,7 @@ pub async fn run(
} }
let fut_chat = endpoint.endpoint_builder().handler(ingress_chat).start(); let fut_chat = endpoint.endpoint_builder().handler(ingress_chat).start();
(Box::pin(fut_chat), Some(model.card().clone())) Box::pin(fut_chat)
} }
EngineConfig::StaticCore { EngineConfig::StaticCore {
engine: inner_engine, engine: inner_engine,
...@@ -92,7 +93,7 @@ pub async fn run( ...@@ -92,7 +93,7 @@ pub async fn run(
let fut = endpoint.endpoint_builder().handler(ingress).start(); let fut = endpoint.endpoint_builder().handler(ingress).start();
(Box::pin(fut), Some(model.card().clone())) Box::pin(fut)
} }
EngineConfig::StaticRemote(_) => { EngineConfig::StaticRemote(_) => {
panic!("StaticRemote definitions are only for the frontend end node."); panic!("StaticRemote definitions are only for the frontend end node.");
...@@ -106,7 +107,7 @@ pub async fn run( ...@@ -106,7 +107,7 @@ pub async fn run(
// Note: We must return rt_result to propagate the actual error back to the user. // Note: We must return rt_result to propagate the actual error back to the user.
// If we don't return the specific error, the programmer/user won't know what actually // If we don't return the specific error, the programmer/user won't know what actually
// caused the endpoint service to fail, making debugging much more difficult. // caused the endpoint service to fail, making debugging much more difficult.
let result = tokio::select! { tokio::select! {
rt_result = rt_fut => { rt_result = rt_fut => {
tracing::debug!("Endpoint service completed"); tracing::debug!("Endpoint service completed");
match rt_result { match rt_result {
...@@ -124,21 +125,7 @@ pub async fn run( ...@@ -124,21 +125,7 @@ pub async fn run(
tracing::debug!("Endpoint service cancelled"); tracing::debug!("Endpoint service cancelled");
Ok(()) Ok(())
} }
};
// If we got an error, return it
result?;
// Cleanup on shutdown
if let Some(mut card) = card
&& let Err(err) = card
.delete_from_nats(distributed_runtime.nats_client())
.await
{
tracing::error!(%err, "delete_from_nats error on shutdown");
} }
Ok(())
} }
#[cfg(test)] #[cfg(test)]
......
...@@ -25,7 +25,7 @@ use dynamo_runtime::DistributedRuntime; ...@@ -25,7 +25,7 @@ use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::storage::key_value_store::{ use dynamo_runtime::storage::key_value_store::{
EtcdStore, Key, KeyValueStore, KeyValueStoreManager, EtcdStore, Key, KeyValueStore, KeyValueStoreManager,
}; };
use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats}; use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer; use tokenizers::Tokenizer as HfTokenizer;
...@@ -352,20 +352,6 @@ impl ModelDeploymentCard { ...@@ -352,20 +352,6 @@ impl ModelDeploymentCard {
} }
} }
/// Delete this card from the key-value store and it's URLs from the object store
pub async fn delete_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
let nats_addr = nats_client.addr();
let bucket_name = self.slug();
tracing::trace!(
nats_addr,
%bucket_name,
"Delete model deployment card from NATS"
);
nats_client
.object_store_delete_bucket(bucket_name.as_ref())
.await
}
/// Allow user to override the name we register this model under. /// Allow user to override the name we register this model under.
/// Corresponds to vllm's `--served-model-name`. /// Corresponds to vllm's `--served-model-name`.
pub fn set_name(&mut self, name: &str) { pub fn set_name(&mut self, name: &str) {
......
...@@ -190,12 +190,12 @@ pub mod llm_kvbm { ...@@ -190,12 +190,12 @@ pub mod llm_kvbm {
bytes: Vec<u8>, bytes: Vec<u8>,
) -> Result<()> { ) -> Result<()> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref()); let subject = format!("{}.{}", self.subject(), event_name.as_ref());
self.drt()
.nats_client() let Some(nats_client) = self.drt().nats_client() else {
.client() anyhow::bail!("KVBMDynamoRuntimeComponent EventPublisher requires NATS");
.publish(subject, bytes.into()) };
.await nats_client.client().publish(subject, bytes.into()).await?;
.map_err(|e| anyhow::anyhow!("Failed to publish to NATS: {}", e)) Ok(())
} }
} }
......
...@@ -290,7 +290,9 @@ impl Component { ...@@ -290,7 +290,9 @@ impl Component {
pub async fn scrape_stats(&self, timeout: Duration) -> Result<ServiceSet> { pub async fn scrape_stats(&self, timeout: Duration) -> Result<ServiceSet> {
// Debug: scraping stats for component // Debug: scraping stats for component
let service_name = self.service_name(); let service_name = self.service_name();
let service_client = self.drt().service_client(); let Some(service_client) = self.drt().service_client() else {
anyhow::bail!("ServiceSet is gathered via NATS, do not call this in non-NATS setups.");
};
service_client service_client
.collect_services(&service_name, timeout) .collect_services(&service_name, timeout)
.await .await
......
...@@ -31,12 +31,11 @@ impl EventPublisher for Component { ...@@ -31,12 +31,11 @@ impl EventPublisher for Component {
bytes: Vec<u8>, bytes: Vec<u8>,
) -> Result<()> { ) -> Result<()> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref()); let subject = format!("{}.{}", self.subject(), event_name.as_ref());
Ok(self let Some(nats_client) = self.drt().nats_client() else {
.drt() anyhow::bail!("KV router's EventPublisher requires NATS");
.nats_client() };
.client() nats_client.client().publish(subject, bytes.into()).await?;
.publish(subject, bytes.into()) Ok(())
.await?)
} }
} }
...@@ -47,7 +46,10 @@ impl EventSubscriber for Component { ...@@ -47,7 +46,10 @@ impl EventSubscriber for Component {
event_name: impl AsRef<str> + Send + Sync, event_name: impl AsRef<str> + Send + Sync,
) -> Result<async_nats::Subscriber> { ) -> Result<async_nats::Subscriber> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref()); let subject = format!("{}.{}", self.subject(), event_name.as_ref());
Ok(self.drt().nats_client().client().subscribe(subject).await?) let Some(nats_client) = self.drt().nats_client() else {
anyhow::bail!("KV router's EventSubscriber requires NATS");
};
Ok(nats_client.client().subscribe(subject).await?)
} }
async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>( async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
......
...@@ -31,12 +31,11 @@ impl EventPublisher for Namespace { ...@@ -31,12 +31,11 @@ impl EventPublisher for Namespace {
bytes: Vec<u8>, bytes: Vec<u8>,
) -> Result<()> { ) -> Result<()> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref()); let subject = format!("{}.{}", self.subject(), event_name.as_ref());
Ok(self let Some(nats_client) = self.drt().nats_client() else {
.drt() anyhow::bail!("KV router's Namespace EventPublisher requires NATS");
.nats_client() };
.client() nats_client.client().publish(subject, bytes.into()).await?;
.publish(subject, bytes.into()) Ok(())
.await?)
} }
} }
...@@ -47,7 +46,10 @@ impl EventSubscriber for Namespace { ...@@ -47,7 +46,10 @@ impl EventSubscriber for Namespace {
event_name: impl AsRef<str> + Send + Sync, event_name: impl AsRef<str> + Send + Sync,
) -> Result<async_nats::Subscriber> { ) -> Result<async_nats::Subscriber> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref()); let subject = format!("{}.{}", self.subject(), event_name.as_ref());
Ok(self.drt().nats_client().client().subscribe(subject).await?) let Some(nats_client) = self.drt().nats_client() else {
anyhow::bail!("KV router's Namespace EventSubscriber requires NATS");
};
Ok(nats_client.client().subscribe(subject).await?)
} }
async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>( async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use async_nats::service::Service as NatsService;
use async_nats::service::ServiceExt as _;
use derive_builder::Builder;
use derive_getters::Dissolve; use derive_getters::Dissolve;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Mutex; use std::sync::{Arc, Mutex};
use super::*; use crate::component::Component;
pub use super::endpoint::EndpointStats; pub use super::endpoint::EndpointStats;
use educe::Educe;
type StatsHandlerRegistry = Arc<Mutex<HashMap<String, EndpointStatsHandler>>>;
pub type StatsHandler = pub type StatsHandler =
Box<dyn FnMut(String, EndpointStats) -> serde_json::Value + Send + Sync + 'static>; Box<dyn FnMut(String, EndpointStats) -> serde_json::Value + Send + Sync + 'static>;
pub type EndpointStatsHandler = pub type EndpointStatsHandler =
Box<dyn FnMut(EndpointStats) -> serde_json::Value + Send + Sync + 'static>; Box<dyn FnMut(EndpointStats) -> serde_json::Value + Send + Sync + 'static>;
pub const PROJECT_NAME: &str = "Dynamo"; pub const PROJECT_NAME: &str = "Dynamo";
const SERVICE_VERSION: &str = env!("CARGO_PKG_VERSION");
#[derive(Educe, Builder, Dissolve)] #[derive(Educe, Builder, Dissolve)]
#[educe(Debug)] #[educe(Debug)]
...@@ -29,75 +37,91 @@ pub struct ServiceConfig { ...@@ -29,75 +37,91 @@ pub struct ServiceConfig {
impl ServiceConfigBuilder { impl ServiceConfigBuilder {
/// Create the [`Component`]'s service and store it in the registry. /// Create the [`Component`]'s service and store it in the registry.
pub async fn create(self) -> Result<Component> { pub async fn create(self) -> anyhow::Result<Component> {
let (component, description) = self.build_internal()?.dissolve(); let (component, description) = self.build_internal()?.dissolve();
let version = "0.0.1".to_string();
let service_name = component.service_name(); let service_name = component.service_name();
log::debug!("component: {component}; creating, service_name: {service_name}");
let description = description.unwrap_or(format!( // Pre-check to save cost of creating the service, but don't hold the lock
"{PROJECT_NAME} component {} in namespace {}", if component
component.name, component.namespace .drt
)); .component_registry
.inner
let stats_handler_registry: Arc<Mutex<HashMap<String, EndpointStatsHandler>>> = .lock()
Arc::new(Mutex::new(HashMap::new())); .await
.services
.contains_key(&service_name)
{
anyhow::bail!("Service {service_name} already exists");
}
let stats_handler_registry_clone = stats_handler_registry.clone(); let Some(nats_client) = component.drt.nats_client() else {
anyhow::bail!("Cannot create NATS service without NATS.");
};
let (nats_service, stats_reg) =
build_nats_service(nats_client, &component, description).await?;
let mut guard = component.drt.component_registry.inner.lock().await; let mut guard = component.drt.component_registry.inner.lock().await;
if !guard.services.contains_key(&service_name) {
// Normal case
guard.services.insert(service_name.clone(), nats_service);
guard.stats_handlers.insert(service_name, stats_reg);
drop(guard);
} else {
drop(guard);
let _ = nats_service.stop().await;
return Err(anyhow::anyhow!(
"Service create race for {service_name}, now already exists"
));
}
if guard.services.contains_key(&service_name) { // Register metrics callback. CRITICAL: Never fail service creation for metrics issues.
return Err(anyhow::anyhow!("Service already exists")); if let Err(err) = component.start_scraping_nats_service_component_metrics() {
tracing::debug!(
"Metrics registration failed for '{}': {}",
component.service_name(),
err
);
} }
Ok(component)
}
}
async fn build_nats_service(
nats_client: &crate::transports::nats::Client,
component: &Component,
description: Option<String>,
) -> anyhow::Result<(NatsService, StatsHandlerRegistry)> {
let service_name = component.service_name();
tracing::trace!("component: {component}; creating, service_name: {service_name}");
// create service on the secondary runtime let description = description.unwrap_or(format!(
let builder = component.drt.nats_client.client().service_builder(); "{PROJECT_NAME} component {} in namespace {}",
component.name, component.namespace
));
tracing::debug!("Starting service: {}", service_name); let stats_handler_registry: StatsHandlerRegistry = Arc::new(Mutex::new(HashMap::new()));
let service_builder = builder let stats_handler_registry_clone = stats_handler_registry.clone();
let nats_service_builder = nats_client.client().service_builder();
let nats_service_builder =
nats_service_builder
.description(description) .description(description)
.stats_handler(move |name, stats| { .stats_handler(move |name, stats| {
log::trace!("stats_handler: {name}, {stats:?}"); tracing::trace!("stats_handler: {name}, {stats:?}");
let mut guard = stats_handler_registry.lock().unwrap(); let mut guard = stats_handler_registry.lock().unwrap();
match guard.get_mut(&name) { match guard.get_mut(&name) {
Some(handler) => handler(stats), Some(handler) => handler(stats),
None => serde_json::Value::Null, None => serde_json::Value::Null,
} }
}); });
tracing::debug!("Got builder"); let nats_service = nats_service_builder
let service = service_builder .start(service_name, SERVICE_VERSION.to_string())
.start(service_name.clone(), version) .await
.await .map_err(|e| anyhow::anyhow!("Failed to start NATS service: {e}"))?;
.map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
// new copy of service_name as the previous one is moved into the task above
let service_name = component.service_name();
// insert the service into the registry Ok((nats_service, stats_handler_registry_clone))
guard.services.insert(service_name.clone(), service);
// insert the stats handler into the registry
guard
.stats_handlers
.insert(service_name, stats_handler_registry_clone);
// drop the guard to unlock the mutex
drop(guard);
// Register metrics callback. CRITICAL: Never fail service creation for metrics issues.
if let Err(err) = component.start_scraping_nats_service_component_metrics() {
tracing::debug!(
"Metrics registration failed for '{}': {}",
component.service_name(),
err
);
}
Ok(component)
}
} }
impl ServiceConfigBuilder { impl ServiceConfigBuilder {
......
...@@ -55,7 +55,7 @@ impl DistributedRuntime { ...@@ -55,7 +55,7 @@ impl DistributedRuntime {
(Some(etcd_client), store) (Some(etcd_client), store)
}; };
let nats_client = nats_config.clone().connect().await?; let nats_client = Some(nats_config.clone().connect().await?);
// Start system status server for health and metrics if enabled in configuration // Start system status server for health and metrics if enabled in configuration
let config = crate::config::RuntimeConfig::from_settings().unwrap_or_default(); let config = crate::config::RuntimeConfig::from_settings().unwrap_or_default();
...@@ -96,22 +96,24 @@ impl DistributedRuntime { ...@@ -96,22 +96,24 @@ impl DistributedRuntime {
system_health, system_health,
}; };
let nats_client_metrics = DRTNatsClientPrometheusMetrics::new( if let Some(nats_client_for_metrics) = nats_client_for_metrics {
&distributed_runtime, let nats_client_metrics = DRTNatsClientPrometheusMetrics::new(
nats_client_for_metrics.client().clone(), &distributed_runtime,
)?; nats_client_for_metrics.client().clone(),
let mut drt_hierarchies = distributed_runtime.parent_hierarchy(); )?;
drt_hierarchies.push(distributed_runtime.hierarchy()); let mut drt_hierarchies = distributed_runtime.parent_hierarchy();
// Register a callback to update NATS client metrics drt_hierarchies.push(distributed_runtime.hierarchy());
let nats_client_callback = Arc::new({ // Register a callback to update NATS client metrics
let nats_client_clone = nats_client_metrics.clone(); let nats_client_callback = Arc::new({
move || { let nats_client_clone = nats_client_metrics.clone();
nats_client_clone.set_from_client_stats(); move || {
Ok(()) nats_client_clone.set_from_client_stats();
} Ok(())
}); }
distributed_runtime });
.register_prometheus_update_callback(drt_hierarchies, nats_client_callback); distributed_runtime
.register_prometheus_update_callback(drt_hierarchies, nats_client_callback);
}
// Initialize the uptime gauge in SystemHealth // Initialize the uptime gauge in SystemHealth
distributed_runtime distributed_runtime
...@@ -245,8 +247,8 @@ impl DistributedRuntime { ...@@ -245,8 +247,8 @@ impl DistributedRuntime {
) )
} }
pub(crate) fn service_client(&self) -> ServiceClient { pub(crate) fn service_client(&self) -> Option<ServiceClient> {
ServiceClient::new(self.nats_client.clone()) self.nats_client().map(|nc| ServiceClient::new(nc.clone()))
} }
pub async fn tcp_server(&self) -> Result<Arc<tcp::server::TcpStreamServer>> { pub async fn tcp_server(&self) -> Result<Arc<tcp::server::TcpStreamServer>> {
...@@ -261,8 +263,8 @@ impl DistributedRuntime { ...@@ -261,8 +263,8 @@ impl DistributedRuntime {
.clone()) .clone())
} }
pub fn nats_client(&self) -> nats::Client { pub fn nats_client(&self) -> Option<&nats::Client> {
self.nats_client.clone() self.nats_client.as_ref()
} }
/// Get system status server information if available /// Get system status server information if available
......
...@@ -187,7 +187,7 @@ pub struct DistributedRuntime { ...@@ -187,7 +187,7 @@ pub struct DistributedRuntime {
// we might consider a unifed transport manager here // we might consider a unifed transport manager here
etcd_client: Option<transports::etcd::Client>, etcd_client: Option<transports::etcd::Client>,
nats_client: transports::nats::Client, nats_client: Option<transports::nats::Client>,
store: Arc<dyn KeyValueStore>, store: Arc<dyn KeyValueStore>,
tcp_server: Arc<OnceCell<Arc<transports::tcp::server::TcpStreamServer>>>, tcp_server: Arc<OnceCell<Arc<transports::tcp::server::TcpStreamServer>>>,
system_status_server: Arc<OnceLock<Arc<system_status_server::SystemStatusServerInfo>>>, system_status_server: Arc<OnceLock<Arc<system_status_server::SystemStatusServerInfo>>>,
......
...@@ -89,8 +89,11 @@ impl RouterMode { ...@@ -89,8 +89,11 @@ impl RouterMode {
} }
async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> { async fn addressed_router(endpoint: &Endpoint) -> anyhow::Result<Arc<AddressedPushRouter>> {
let Some(nats_client) = endpoint.drt().nats_client() else {
anyhow::bail!("Missing NATS. Please ensure it is running and accessible.");
};
AddressedPushRouter::new( AddressedPushRouter::new(
endpoint.drt().nats_client.client().clone(), nats_client.client().clone(),
endpoint.drt().tcp_server().await?, endpoint.drt().tcp_server().await?,
) )
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment