Unverified Commit 69797b5a authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Only monitor NATS metrics if using NATS request plane (#4442)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent a8e5328e
......@@ -44,7 +44,6 @@ async def init(runtime: DistributedRuntime, ns: str):
A `Component` can serve multiple endpoints
"""
component = runtime.namespace(ns).component("backend")
await component.create_service()
endpoint = component.endpoint("generate")
print("Started server instance")
......
......@@ -49,7 +49,6 @@ async def init(runtime: DistributedRuntime, ns: str):
A `Component` can serve multiple endpoints
"""
component = runtime.namespace(ns).component("backend")
await component.create_service()
endpoint = component.endpoint("generate")
print("Started server instance")
......
......@@ -90,7 +90,6 @@ async def init(runtime: DistributedRuntime, config: Config):
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(
......
......@@ -103,7 +103,6 @@ async def init(runtime: DistributedRuntime, config: Config):
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(
......
......@@ -101,7 +101,6 @@ async def init(runtime: DistributedRuntime, config: Config):
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
endpoint = component.endpoint(config.endpoint)
await register_llm(
......
......@@ -36,7 +36,6 @@ async def worker(runtime: DistributedRuntime) -> None:
async def init(runtime: DistributedRuntime):
# Create component and endpoint
component: Component = runtime.namespace("ns556").component("cp556")
await component.create_service()
endpoint: Endpoint = component.endpoint("ep556")
......
......@@ -67,7 +67,6 @@ async def worker(runtime: DistributedRuntime) -> None:
async def init(runtime: DistributedRuntime):
# Create component and endpoint
component: Component = runtime.namespace("ns557").component("cp557")
await component.create_service()
endpoint: Endpoint = component.endpoint("ep557")
......
......@@ -32,7 +32,6 @@ class RequestHandler:
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
component = runtime.namespace("examples/pipeline").component("backend")
await component.create_service()
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(RequestHandler().generate)
......
......@@ -44,7 +44,6 @@ async def worker(runtime: DistributedRuntime):
# create endpoint service for frontend component
component = runtime.namespace("examples/pipeline").component("frontend")
await component.create_service()
endpoint = component.endpoint("generate")
......
......@@ -44,7 +44,6 @@ async def worker(runtime: DistributedRuntime):
# create endpoint service for middle component
component = runtime.namespace("examples/pipeline").component("middle")
await component.create_service()
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(RequestHandler(backend).generate)
......
......@@ -42,7 +42,6 @@ async def worker(runtime: DistributedRuntime):
A `Component` can serve multiple endpoints
"""
component = runtime.namespace("dynamo").component("backend")
await component.create_service()
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(RequestHandler().generate)
......
......@@ -478,7 +478,12 @@ impl DistributedRuntime {
let runtime_config = DistributedConfig {
store_backend: selected_kv_store,
nats_config: dynamo_runtime::transports::nats::ClientOptions::default(),
// We only need NATS here to monitor it's metrics, so only if it's our request plane.
nats_config: if request_plane.is_nats() {
Some(dynamo_runtime::transports::nats::ClientOptions::default())
} else {
None
},
request_plane,
};
let inner = runtime
......@@ -565,15 +570,6 @@ impl Component {
})
}
/// NATS specific stats/metrics call
fn create_service<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let mut inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner.add_stats_service().await.map_err(to_pyerr)?;
Ok(())
})
}
/// Get a RuntimeMetrics helper for creating Prometheus metrics
#[getter]
fn metrics(&self) -> prometheus_metrics::RuntimeMetrics {
......
......@@ -101,12 +101,6 @@ class Component:
...
async def create_service(self) -> None:
"""
Create a service
"""
...
def endpoint(self, name: str) -> Endpoint:
"""
Create an endpoint
......@@ -403,8 +397,7 @@ class WorkerMetricsPublisher:
def create_endpoint(self, component: Component, metrics_labels: Optional[List[Tuple[str, str]]] = None) -> None:
"""
Similar to Component.create_service, but only service created through
this method will interact with KV router of the same component.
Only service created through this method will interact with KV router of the same component.
Args:
component: The component to create the endpoint for
......
......@@ -129,7 +129,6 @@ async def server(runtime, namespace):
async def init_server():
"""Initialize the test server component and serve the generate endpoint"""
component = runtime.namespace(namespace).component("backend")
await component.create_service()
endpoint = component.endpoint("generate")
print("Started test server instance")
......
......@@ -223,7 +223,6 @@ async def test_event_handler(distributed_runtime):
namespace = "kv_test"
component = "event"
kv_listener = distributed_runtime.namespace(namespace).component(component)
await kv_listener.create_service()
# publisher
worker_id = 233
......@@ -281,7 +280,6 @@ async def test_approx_kv_indexer(distributed_runtime):
namespace = "kv_test"
component = "approx_kv"
kv_listener = distributed_runtime.namespace(namespace).component(component)
await kv_listener.create_service()
indexer = ApproxKvIndexer(kv_listener, kv_block_size, 30.0)
......
......@@ -14,7 +14,6 @@ async def get_metrics_runtime(runtime, endpoint_name):
"""Helper to create a unique metrics runtime for each test."""
namespace = runtime.namespace("test_metrics_ns")
component = namespace.component("test_metrics_comp")
await component.create_service()
endpoint = component.endpoint(endpoint_name)
return endpoint.metrics
......
......@@ -16,7 +16,6 @@ TEST_END_TO_END = os.environ.get("TEST_END_TO_END", 0)
@dynamo_worker()
async def test_register(runtime: DistributedRuntime):
component = runtime.namespace("test").component("tensor")
await component.create_service()
endpoint = component.endpoint("generate")
......
......@@ -374,6 +374,7 @@ impl Component {
unimplemented!("collect_stats")
}
// Gather NATS metrics
pub async fn add_stats_service(&mut self) -> anyhow::Result<()> {
let service_name = self.service_name();
......@@ -387,7 +388,10 @@ impl Component {
.services
.contains_key(&service_name)
{
anyhow::bail!("Service {service_name} already exists");
// The NATS service is per component, but it is called from `serve_endpoint`, and there
// are often multiple endpoints for a component (e.g. `clear_kv_blocks` and `generate`).
tracing::trace!("Service {service_name} already exists");
return Ok(());
}
let Some(nats_client) = self.drt.nats_client() else {
......@@ -402,35 +406,24 @@ impl Component {
// Normal case
guard.services.insert(service_name.clone(), nats_service);
guard.stats_handlers.insert(service_name.clone(), stats_reg);
tracing::info!("Added NATS / stats service {service_name}");
drop(guard);
} else {
drop(guard);
let _ = nats_service.stop().await;
return Err(anyhow::anyhow!(
"Service create race for {service_name}, now already exists"
));
// The NATS service is per component, but it is called from `serve_endpoint`, and there
// are often multiple endpoints for a component (e.g. `clear_kv_blocks` and `generate`).
return Ok(());
}
// Register metrics callback. CRITICAL: Never fail service creation for metrics issues.
// Only enable NATS service metrics collection when using NATS request plane mode
let request_plane_mode = self.drt.request_plane();
match request_plane_mode {
RequestPlaneMode::Nats => {
if let Err(err) = self.start_scraping_nats_service_component_metrics() {
tracing::debug!(
"Metrics registration failed for '{}': {}",
self.service_name(),
err
);
}
}
_ => {
tracing::info!(
"Skipping NATS service metrics collection for '{}' - request plane mode is '{}'",
self.service_name(),
request_plane_mode
);
}
if let Err(err) = self.start_scraping_nats_service_component_metrics() {
tracing::debug!(
"Metrics registration failed for '{}': {}",
self.service_name(),
err
);
}
Ok(())
}
......
......@@ -64,7 +64,7 @@ impl EndpointConfigBuilder {
pub async fn start(self) -> Result<()> {
let (
endpoint,
mut endpoint,
handler,
stats_handler,
metrics_labels,
......@@ -86,39 +86,32 @@ impl EndpointConfigBuilder {
// Add metrics to the handler. The endpoint provides additional information to the handler.
handler.add_metrics(&endpoint, metrics_labels.as_deref())?;
let registry = endpoint.drt().component_registry().inner.lock().await;
// Note: NATS service group is no longer needed here as the NetworkManager
// handles all transport-specific initialization internally
let _group = registry
.services
.get(&service_name)
.map(|service| service.group(endpoint.component.service_name()))
.ok_or(anyhow::anyhow!("Service not found"))?;
// get the stats handler map
let handler_map = registry
.stats_handlers
.get(&service_name)
.cloned()
.expect("no stats handler registry; this is unexpected");
drop(registry);
// insert the stats handler
if let Some(stats_handler) = stats_handler {
handler_map
.lock()
.insert(endpoint.subject_to(connection_id), stats_handler);
}
// Determine request plane mode
let request_plane_mode = endpoint.drt().request_plane();
if request_plane_mode.is_nats() {
// We only need the service if we want NATS metrics.
// TODO: This is called for every endpoint of a component. Ideally we only call it once
// on the component.
endpoint.component.add_stats_service().await?;
}
tracing::info!(
"Endpoint starting with request plane mode: {:?}",
request_plane_mode
);
// Insert the stats handler. depends on NATS.
if let Some(stats_handler) = stats_handler {
let registry = endpoint.drt().component_registry().inner.lock().await;
let handler_map = registry
.stats_handlers
.get(&service_name)
.cloned()
.expect("no stats handler registry; this is unexpected");
handler_map
.lock()
.insert(endpoint.subject_to(connection_id), stats_handler);
}
// This creates a child token of the runtime's endpoint_shutdown_token. That token is
// cancelled first as part of graceful shutdown. See Runtime::shutdown.
let endpoint_shutdown_token = endpoint.drt().child_token();
......
......@@ -114,7 +114,10 @@ impl DistributedRuntime {
KeyValueStoreSelect::Memory => (None, KeyValueStoreManager::memory()),
};
let nats_client = Some(nats_config.clone().connect().await?);
let nats_client = match nats_config {
Some(nc) => Some(nc.connect().await?),
None => None,
};
// Start system status server for health and metrics if enabled in configuration
let config = crate::config::RuntimeConfig::from_settings().unwrap_or_default();
......@@ -455,16 +458,21 @@ impl DistributedRuntime {
#[derive(Dissolve)]
pub struct DistributedConfig {
pub store_backend: KeyValueStoreSelect,
pub nats_config: nats::ClientOptions,
pub nats_config: Option<nats::ClientOptions>,
pub request_plane: RequestPlaneMode,
}
impl DistributedConfig {
pub fn from_settings() -> DistributedConfig {
let request_plane = RequestPlaneMode::from_env();
DistributedConfig {
store_backend: KeyValueStoreSelect::Etcd(Box::default()),
nats_config: nats::ClientOptions::default(),
request_plane: RequestPlaneMode::from_env(),
nats_config: if request_plane.is_nats() {
Some(nats::ClientOptions::default())
} else {
None
},
request_plane,
}
}
......@@ -473,10 +481,15 @@ impl DistributedConfig {
attach_lease: false,
..Default::default()
};
let request_plane = RequestPlaneMode::from_env();
DistributedConfig {
store_backend: KeyValueStoreSelect::Etcd(Box::new(etcd_config)),
nats_config: nats::ClientOptions::default(),
request_plane: RequestPlaneMode::from_env(),
nats_config: if request_plane.is_nats() {
Some(nats::ClientOptions::default())
} else {
None
},
request_plane,
}
}
}
......@@ -538,6 +551,10 @@ impl RequestPlaneMode {
.and_then(|s| s.parse().ok())
.unwrap_or_default()
}
pub fn is_nats(&self) -> bool {
matches!(self, RequestPlaneMode::Nats)
}
}
pub mod distributed_test_utils {
......@@ -553,7 +570,7 @@ pub mod distributed_test_utils {
let rt = crate::Runtime::from_current().unwrap();
let config = super::DistributedConfig {
store_backend: KeyValueStoreSelect::Memory,
nats_config: nats::ClientOptions::default(),
nats_config: Some(nats::ClientOptions::default()),
request_plane: crate::distributed::RequestPlaneMode::default(),
};
super::DistributedRuntime::new(rt, config).await.unwrap()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment