"...models/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "2b53953096b183f184b6dba2f210c1f92b1666ae"
Commit 85cc7b67 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

refactor: service/endpoint stats_handler (#282)

parent 0a393dcb
...@@ -67,19 +67,20 @@ async fn backend(runtime: DistributedRuntime) -> Result<()> { ...@@ -67,19 +67,20 @@ async fn backend(runtime: DistributedRuntime) -> Result<()> {
// make the ingress discoverable via a component service // make the ingress discoverable via a component service
// we must first create a service, then we can attach one more more endpoints // we must first create a service, then we can attach one more more endpoints
runtime runtime
.namespace(DEFAULT_NAMESPACE)? .namespace(DEFAULT_NAMESPACE)?
.component("backend")? .component("backend")?
.service_builder() .service_builder()
// Dummy stats handler to demonstrate how to attach a custom stats handler
.stats_handler(Some(Box::new(|_name, _stats| {
let stats = MyStats { val: 10 };
serde_json::to_value(stats).unwrap()
})))
.create() .create()
.await? .await?
.endpoint("generate") .endpoint("generate")
.endpoint_builder() .endpoint_builder()
.stats_handler(|stats| {
println!("stats: {:?}", stats);
let stats = MyStats { val: 10 };
serde_json::to_value(stats).unwrap()
})
.handler(ingress) .handler(ingress)
.start() .start()
.await .await
......
...@@ -78,7 +78,7 @@ impl KvMetricsPublisher { ...@@ -78,7 +78,7 @@ impl KvMetricsPublisher {
let rs_component = component.inner.clone(); let rs_component = component.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
rs_publisher rs_publisher
.create_service(rs_component) .create_endpoint(rs_component)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(()) Ok(())
......
...@@ -14,10 +14,20 @@ ...@@ -14,10 +14,20 @@
// limitations under the License. // limitations under the License.
use crate::kv_router::{indexer::RouterEvent, protocols::*, KV_EVENT_SUBJECT}; use crate::kv_router::{indexer::RouterEvent, protocols::*, KV_EVENT_SUBJECT};
use async_trait::async_trait;
use futures::stream;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing as log; use tracing as log;
use triton_distributed_runtime::{component::Component, DistributedRuntime, Result}; use triton_distributed_runtime::{
component::Component,
pipeline::{
network::Ingress, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream,
SingleIn,
},
protocols::annotated::Annotated,
DistributedRuntime, Error, Result,
};
pub struct KvEventPublisher { pub struct KvEventPublisher {
tx: mpsc::UnboundedSender<KvCacheEvent>, tx: mpsc::UnboundedSender<KvCacheEvent>,
...@@ -79,17 +89,49 @@ impl KvMetricsPublisher { ...@@ -79,17 +89,49 @@ impl KvMetricsPublisher {
self.tx.send(metrics) self.tx.send(metrics)
} }
pub async fn create_service(&self, component: Component) -> Result<()> { pub async fn create_endpoint(&self, component: Component) -> Result<()> {
let mut metrics_rx = self.rx.clone(); let mut metrics_rx = self.rx.clone();
let _ = component let handler = Arc::new(KvLoadEndpoingHander::new(metrics_rx.clone()));
let handler = Ingress::for_engine(handler)?;
component
.service_builder() .service_builder()
.stats_handler(Some(Box::new(move |name, stats| { .create()
log::debug!("[IN worker?] Stats for service {}: {:?}", name, stats); .await?
.endpoint("load_metrics")
.endpoint_builder()
.stats_handler(move |_| {
let metrics = metrics_rx.borrow_and_update().clone(); let metrics = metrics_rx.borrow_and_update().clone();
serde_json::to_value(&*metrics).unwrap() serde_json::to_value(&*metrics).unwrap()
}))) })
.create() .handler(handler)
.await?; .start()
Ok(()) .await
}
}
struct KvLoadEndpoingHander {
metrics_rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>,
}
impl KvLoadEndpoingHander {
pub fn new(metrics_rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>) -> Self {
Self { metrics_rx }
}
}
#[async_trait]
impl AsyncEngine<SingleIn<()>, ManyOut<Annotated<ForwardPassMetrics>>, Error>
for KvLoadEndpoingHander
{
async fn generate(
&self,
request: SingleIn<()>,
) -> Result<ManyOut<Annotated<ForwardPassMetrics>>> {
let context = request.context();
let metrics = self.metrics_rx.borrow().clone();
let metrics = (*metrics).clone();
let stream = stream::iter(vec![Annotated::from_data(metrics)]);
Ok(ResponseStream::new(Box::pin(stream), context))
} }
} }
...@@ -56,6 +56,7 @@ use derive_builder::Builder; ...@@ -56,6 +56,7 @@ use derive_builder::Builder;
use derive_getters::Getters; use derive_getters::Getters;
use educe::Educe; use educe::Educe;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use service::EndpointStatsHandler;
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use validator::{Validate, ValidationError}; use validator::{Validate, ValidationError};
...@@ -73,9 +74,15 @@ pub enum TransportType { ...@@ -73,9 +74,15 @@ pub enum TransportType {
NatsTcp(String), NatsTcp(String),
} }
#[derive(Default)]
pub struct RegistryInner {
services: HashMap<String, Service>,
stats_handlers: HashMap<String, Arc<std::sync::Mutex<HashMap<String, EndpointStatsHandler>>>>,
}
#[derive(Clone)] #[derive(Clone)]
pub struct Registry { pub struct Registry {
services: Arc<tokio::sync::Mutex<HashMap<String, Service>>>, inner: Arc<tokio::sync::Mutex<RegistryInner>>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
...@@ -111,6 +118,12 @@ pub struct Component { ...@@ -111,6 +118,12 @@ pub struct Component {
namespace: String, namespace: String,
} }
impl std::fmt::Display for Component {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}.{}", self.namespace, self.name)
}
}
impl DistributedRuntimeProvider for Component { impl DistributedRuntimeProvider for Component {
fn drt(&self) -> &DistributedRuntime { fn drt(&self) -> &DistributedRuntime {
&self.drt &self.drt
......
...@@ -33,6 +33,11 @@ pub struct EndpointConfig { ...@@ -33,6 +33,11 @@ pub struct EndpointConfig {
/// Endpoint handler /// Endpoint handler
#[educe(Debug(ignore))] #[educe(Debug(ignore))]
handler: Arc<dyn PushWorkHandler>, handler: Arc<dyn PushWorkHandler>,
/// Stats handler
#[educe(Debug(ignore))]
#[builder(default, private)]
_stats_handler: Option<EndpointStatsHandler>,
} }
impl EndpointConfigBuilder { impl EndpointConfigBuilder {
...@@ -40,8 +45,15 @@ impl EndpointConfigBuilder { ...@@ -40,8 +45,15 @@ impl EndpointConfigBuilder {
Self::default().endpoint(endpoint) Self::default().endpoint(endpoint)
} }
pub fn stats_handler<F>(self, handler: F) -> Self
where
F: FnMut(async_nats::service::endpoint::Stats) -> serde_json::Value + Send + Sync + 'static,
{
self._stats_handler(Some(Box::new(handler)))
}
pub async fn start(self) -> Result<()> { pub async fn start(self) -> Result<()> {
let (endpoint, lease, handler) = self.build_internal()?.dissolve(); let (endpoint, lease, handler, stats_handler) = self.build_internal()?.dissolve();
let lease = lease.unwrap_or(endpoint.drt().primary_lease()); let lease = lease.unwrap_or(endpoint.drt().primary_lease());
tracing::debug!( tracing::debug!(
...@@ -49,18 +61,34 @@ impl EndpointConfigBuilder { ...@@ -49,18 +61,34 @@ impl EndpointConfigBuilder {
endpoint.etcd_path_with_id(lease.id()) endpoint.etcd_path_with_id(lease.id())
); );
let group = endpoint let service_name = endpoint.component.service_name();
.component
.drt // acquire the registry lock
.component_registry let registry = endpoint.drt().component_registry.inner.lock().await;
// get the group
let group = registry
.services .services
.lock() .get(&service_name)
.await
.get(&endpoint.component.etcd_path())
.map(|service| service.group(endpoint.component.service_name())) .map(|service| service.group(endpoint.component.service_name()))
.ok_or(error!("Service not found"))?; .ok_or(error!("Service not found"))?;
// let group = service.group(service_name.as_str()); // get the stats handler map
let handler_map = registry
.stats_handlers
.get(&service_name)
.cloned()
.expect("no stats handler registry; this is unexpected");
drop(registry);
// insert the stats handler
if let Some(stats_handler) = stats_handler {
handler_map
.lock()
.unwrap()
.insert(endpoint.subject_to(lease.id()), stats_handler);
}
// creates an endpoint for the service // creates an endpoint for the service
let service_endpoint = group let service_endpoint = group
...@@ -79,8 +107,6 @@ impl EndpointConfigBuilder { ...@@ -79,8 +107,6 @@ impl EndpointConfigBuilder {
// launch in primary runtime // launch in primary runtime
let task = tokio::spawn(push_endpoint.start(service_endpoint)); let task = tokio::spawn(push_endpoint.start(service_endpoint));
// tracing::debug!(worker_id, "endpoint subject: {}", subject);
// make the components service endpoint discovery in etcd // make the components service endpoint discovery in etcd
// client.register_service() // client.register_service()
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use super::{Component, Registry, Result}; use super::{Component, Registry, RegistryInner, Result};
use async_once_cell::OnceCell; use async_once_cell::OnceCell;
use std::{ use std::{
collections::HashMap, collections::HashMap,
...@@ -30,7 +30,7 @@ impl Default for Registry { ...@@ -30,7 +30,7 @@ impl Default for Registry {
impl Registry { impl Registry {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
services: Arc::new(Mutex::new(HashMap::new())), inner: Arc::new(Mutex::new(RegistryInner::default())),
} }
} }
} }
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
// limitations under the License. // limitations under the License.
use derive_getters::Dissolve; use derive_getters::Dissolve;
use std::collections::HashMap;
use std::sync::Mutex;
use super::*; use super::*;
...@@ -22,6 +24,12 @@ use async_nats::service::{endpoint, Service}; ...@@ -22,6 +24,12 @@ use async_nats::service::{endpoint, Service};
pub type StatsHandler = pub type StatsHandler =
Box<dyn FnMut(String, endpoint::Stats) -> serde_json::Value + Send + Sync + 'static>; Box<dyn FnMut(String, endpoint::Stats) -> serde_json::Value + Send + Sync + 'static>;
pub type EndpointStatsHandler =
Box<dyn FnMut(endpoint::Stats) -> serde_json::Value + Send + Sync + 'static>;
// TODO(rename) - pending rename of project
pub const PROJECT_NAME: &str = "Triton";
#[derive(Educe, Builder, Dissolve)] #[derive(Educe, Builder, Dissolve)]
#[educe(Debug)] #[educe(Debug)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))] #[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
...@@ -32,56 +40,64 @@ pub struct ServiceConfig { ...@@ -32,56 +40,64 @@ pub struct ServiceConfig {
/// Description /// Description
#[builder(default)] #[builder(default)]
description: Option<String>, description: Option<String>,
// todo - make optional - if None, then skip making the endpoint
// and skip making the service-endpoint discoverable.
/// Endpoint handler
#[educe(Debug(ignore))]
#[builder(default)]
stats_handler: Option<StatsHandler>,
} }
impl ServiceConfigBuilder { impl ServiceConfigBuilder {
/// Create the [`Component`]'s service and store it in the registry. /// Create the [`Component`]'s service and store it in the registry.
pub async fn create(self) -> Result<Component> { pub async fn create(self) -> Result<Component> {
let version = "0.0.1".to_string(); let (component, description) = self.build_internal()?.dissolve();
let (component, description, stat_handler) = self.build_internal()?.dissolve(); let version = "0.0.1".to_string();
let service_name = component.service_name(); let service_name = component.service_name();
log::debug!("component: {component}; creating, service_name: {service_name}");
let description = description.unwrap_or(format!( let description = description.unwrap_or(format!(
"Triton Component {} in {}", "{PROJECT_NAME} component {} in namespace {}",
component.name, component.namespace component.name, component.namespace
)); ));
let mut guard = component.drt.component_registry.services.lock().await; let stats_handler_registry: Arc<Mutex<HashMap<String, EndpointStatsHandler>>> =
Arc::new(Mutex::new(HashMap::new()));
let stats_handler_registry_clone = stats_handler_registry.clone();
let mut guard = component.drt.component_registry.inner.lock().await;
if guard.contains_key(&component.etcd_path()) { if guard.services.contains_key(&service_name) {
return Err(anyhow::anyhow!("Service already exists")); return Err(anyhow::anyhow!("Service already exists"));
} }
// create service on the secondary runtime // create service on the secondary runtime
let secondary = component.drt.runtime.secondary();
let builder = component.drt.nats_client.client().service_builder(); let builder = component.drt.nats_client.client().service_builder();
let service = secondary
.spawn(async move {
// unwrap the stats handler
let builder = match stat_handler {
Some(handler) => builder.stats_handler(handler),
None => builder,
};
tracing::debug!("Starting service: {}", service_name); tracing::debug!("Starting service: {}", service_name);
let service = builder
builder
.description(description) .description(description)
.start(service_name.to_string(), version) .stats_handler(move |name, stats| {
.await log::trace!("stats_handler: {name}, {stats:?}");
let mut guard = stats_handler_registry.lock().unwrap();
match guard.get_mut(&name) {
Some(handler) => handler(stats),
None => serde_json::Value::Null,
}
}) })
.await? .start(service_name.clone(), version)
.await
.map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?; .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
guard.insert(component.etcd_path(), service); // new copy of service_name as the previous one is moved into the task above
let service_name = component.service_name();
// insert the service into the registry
guard.services.insert(service_name.clone(), service);
// insert the stats handler into the registry
guard
.stats_handlers
.insert(service_name, stats_handler_registry_clone);
// drop the guard to unlock the mutex
drop(guard); drop(guard);
Ok(component) Ok(component)
...@@ -93,24 +109,3 @@ impl ServiceConfigBuilder { ...@@ -93,24 +109,3 @@ impl ServiceConfigBuilder {
Self::default().component(component) Self::default().component(component)
} }
} }
// // Wrap the optional user callback method in a closure that appends the lease_id to the response
// fn wrap_callback(
// callback: Option<Box<dyn FnMut(String, Stats) -> Value + Send + Sync>>,
// lease_id: i64,
// ) -> Box<dyn FnMut(String, Stats) -> Value + Send + Sync> {
// let callback = Arc::new(Mutex::new(callback)); // Wrap in Arc<Mutex> for shared access
// Box::new(move |subject: String, stats: Stats| -> Value {
// let mut callback_lock = callback.lock().unwrap();
// if let Some(cb) = callback_lock.as_mut() {
// let mut result = cb(subject, stats); // Call the user-defined callback
// if let Some(obj) = result.as_object_mut() {
// obj.insert("lease_id".to_string(), json!(lease_id)); // Append lease_id
// }
// result
// } else {
// json!({ "error": "callback not set", "lease_id": lease_id }) // Default response
// }
// })
// }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment