Commit 85cc7b67 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

refactor: service/endpoint stats_handler (#282)

parent 0a393dcb
......@@ -67,19 +67,20 @@ async fn backend(runtime: DistributedRuntime) -> Result<()> {
// make the ingress discoverable via a component service
// we must first create a service, then we can attach one or more endpoints
runtime
.namespace(DEFAULT_NAMESPACE)?
.component("backend")?
.service_builder()
// Dummy stats handler to demonstrate how to attach a custom stats handler
.stats_handler(Some(Box::new(|_name, _stats| {
let stats = MyStats { val: 10 };
serde_json::to_value(stats).unwrap()
})))
.create()
.await?
.endpoint("generate")
.endpoint_builder()
.stats_handler(|stats| {
println!("stats: {:?}", stats);
let stats = MyStats { val: 10 };
serde_json::to_value(stats).unwrap()
})
.handler(ingress)
.start()
.await
......
......@@ -78,7 +78,7 @@ impl KvMetricsPublisher {
let rs_component = component.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
rs_publisher
.create_service(rs_component)
.create_endpoint(rs_component)
.await
.map_err(to_pyerr)?;
Ok(())
......
......@@ -14,10 +14,20 @@
// limitations under the License.
use crate::kv_router::{indexer::RouterEvent, protocols::*, KV_EVENT_SUBJECT};
use async_trait::async_trait;
use futures::stream;
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing as log;
use triton_distributed_runtime::{component::Component, DistributedRuntime, Result};
use triton_distributed_runtime::{
component::Component,
pipeline::{
network::Ingress, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream,
SingleIn,
},
protocols::annotated::Annotated,
DistributedRuntime, Error, Result,
};
pub struct KvEventPublisher {
tx: mpsc::UnboundedSender<KvCacheEvent>,
......@@ -79,17 +89,49 @@ impl KvMetricsPublisher {
self.tx.send(metrics)
}
pub async fn create_service(&self, component: Component) -> Result<()> {
pub async fn create_endpoint(&self, component: Component) -> Result<()> {
let mut metrics_rx = self.rx.clone();
let _ = component
let handler = Arc::new(KvLoadEndpoingHander::new(metrics_rx.clone()));
let handler = Ingress::for_engine(handler)?;
component
.service_builder()
.stats_handler(Some(Box::new(move |name, stats| {
log::debug!("[IN worker?] Stats for service {}: {:?}", name, stats);
.create()
.await?
.endpoint("load_metrics")
.endpoint_builder()
.stats_handler(move |_| {
let metrics = metrics_rx.borrow_and_update().clone();
serde_json::to_value(&*metrics).unwrap()
})))
.create()
.await?;
Ok(())
})
.handler(handler)
.start()
.await
}
}
/// Engine that answers `load_metrics` requests with the most recent
/// [`ForwardPassMetrics`] snapshot observed on the watch channel.
/// NOTE(review): identifier contains typos ("Endpoing", "Hander") —
/// consider renaming to `KvLoadEndpointHandler` in a follow-up change.
struct KvLoadEndpoingHander {
    // Watch receiver; `borrow()` always yields the latest published metrics.
    metrics_rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>,
}
impl KvLoadEndpoingHander {
pub fn new(metrics_rx: tokio::sync::watch::Receiver<Arc<ForwardPassMetrics>>) -> Self {
Self { metrics_rx }
}
}
#[async_trait]
impl AsyncEngine<SingleIn<()>, ManyOut<Annotated<ForwardPassMetrics>>, Error>
    for KvLoadEndpoingHander
{
    /// Serve the current metrics snapshot as a one-item response stream.
    async fn generate(
        &self,
        request: SingleIn<()>,
    ) -> Result<ManyOut<Annotated<ForwardPassMetrics>>> {
        // Capture the request context so the response stream stays tied to it.
        let context = request.context();
        // Clone the Arc out of the watch channel (cheap refcount bump)...
        let metrics = self.metrics_rx.borrow().clone();
        // ...then deep-copy the payload so the stream owns its data.
        let metrics = (*metrics).clone();
        // Exactly one snapshot per request.
        let stream = stream::iter(vec![Annotated::from_data(metrics)]);
        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}
......@@ -56,6 +56,7 @@ use derive_builder::Builder;
use derive_getters::Getters;
use educe::Educe;
use serde::{Deserialize, Serialize};
use service::EndpointStatsHandler;
use std::{collections::HashMap, sync::Arc};
use validator::{Validate, ValidationError};
......@@ -73,9 +74,15 @@ pub enum TransportType {
NatsTcp(String),
}
/// Interior state of [`Registry`]; guarded by a single async mutex on the
/// outer handle.
#[derive(Default)]
pub struct RegistryInner {
    // Started NATS services keyed by service name.
    services: HashMap<String, Service>,
    // Per-service map of endpoint subject -> stats handler. The inner map is
    // behind a std (sync) Mutex because it is also read from the NATS
    // service's synchronous stats callback.
    stats_handlers: HashMap<String, Arc<std::sync::Mutex<HashMap<String, EndpointStatsHandler>>>>,
}
/// Cloneable handle to the shared service/stats-handler registry.
#[derive(Clone)]
pub struct Registry {
    // Shared interior state; async mutex since it is locked across awaits.
    inner: Arc<tokio::sync::Mutex<RegistryInner>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
......@@ -111,6 +118,12 @@ pub struct Component {
namespace: String,
}
impl std::fmt::Display for Component {
    /// Render as `<namespace>.<name>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{ns}.{name}", ns = self.namespace, name = self.name)
    }
}
impl DistributedRuntimeProvider for Component {
fn drt(&self) -> &DistributedRuntime {
&self.drt
......
......@@ -33,6 +33,11 @@ pub struct EndpointConfig {
/// Endpoint handler
#[educe(Debug(ignore))]
handler: Arc<dyn PushWorkHandler>,
/// Stats handler
#[educe(Debug(ignore))]
#[builder(default, private)]
_stats_handler: Option<EndpointStatsHandler>,
}
impl EndpointConfigBuilder {
......@@ -40,8 +45,15 @@ impl EndpointConfigBuilder {
Self::default().endpoint(endpoint)
}
/// Attach a per-endpoint stats callback. It receives this endpoint's NATS
/// stats and returns the JSON value to report in the service stats reply.
pub fn stats_handler<F>(self, handler: F) -> Self
where
    F: FnMut(async_nats::service::endpoint::Stats) -> serde_json::Value + Send + Sync + 'static,
{
    // Box the closure and store it via the builder's private setter.
    self._stats_handler(Some(Box::new(handler)))
}
pub async fn start(self) -> Result<()> {
let (endpoint, lease, handler) = self.build_internal()?.dissolve();
let (endpoint, lease, handler, stats_handler) = self.build_internal()?.dissolve();
let lease = lease.unwrap_or(endpoint.drt().primary_lease());
tracing::debug!(
......@@ -49,18 +61,34 @@ impl EndpointConfigBuilder {
endpoint.etcd_path_with_id(lease.id())
);
let group = endpoint
.component
.drt
.component_registry
let service_name = endpoint.component.service_name();
// acquire the registry lock
let registry = endpoint.drt().component_registry.inner.lock().await;
// get the group
let group = registry
.services
.lock()
.await
.get(&endpoint.component.etcd_path())
.get(&service_name)
.map(|service| service.group(endpoint.component.service_name()))
.ok_or(error!("Service not found"))?;
// let group = service.group(service_name.as_str());
// get the stats handler map
let handler_map = registry
.stats_handlers
.get(&service_name)
.cloned()
.expect("no stats handler registry; this is unexpected");
drop(registry);
// insert the stats handler
if let Some(stats_handler) = stats_handler {
handler_map
.lock()
.unwrap()
.insert(endpoint.subject_to(lease.id()), stats_handler);
}
// creates an endpoint for the service
let service_endpoint = group
......@@ -79,8 +107,6 @@ impl EndpointConfigBuilder {
// launch in primary runtime
let task = tokio::spawn(push_endpoint.start(service_endpoint));
// tracing::debug!(worker_id, "endpoint subject: {}", subject);
// make the components service endpoint discovery in etcd
// client.register_service()
......
......@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{Component, Registry, Result};
use super::{Component, Registry, RegistryInner, Result};
use async_once_cell::OnceCell;
use std::{
collections::HashMap,
......@@ -30,7 +30,7 @@ impl Default for Registry {
impl Registry {
    /// Create a registry with empty interior state.
    ///
    /// Fix: the struct literal previously contained both the pre-refactor
    /// `services:` initializer and the post-refactor `inner:` initializer;
    /// `Registry` now has only the `inner` field, so only that one is built.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(RegistryInner::default())),
        }
    }
}
......
......@@ -14,6 +14,8 @@
// limitations under the License.
use derive_getters::Dissolve;
use std::collections::HashMap;
use std::sync::Mutex;
use super::*;
......@@ -22,6 +24,12 @@ use async_nats::service::{endpoint, Service};
/// Service-level stats callback: receives the endpoint subject name plus its
/// NATS stats and returns the JSON payload to report.
pub type StatsHandler =
    Box<dyn FnMut(String, endpoint::Stats) -> serde_json::Value + Send + Sync + 'static>;

/// Per-endpoint stats callback: like [`StatsHandler`] but already scoped to a
/// single endpoint, so no subject name is passed.
pub type EndpointStatsHandler =
    Box<dyn FnMut(endpoint::Stats) -> serde_json::Value + Send + Sync + 'static>;

// TODO(rename) - pending rename of project
pub const PROJECT_NAME: &str = "Triton";
#[derive(Educe, Builder, Dissolve)]
#[educe(Debug)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
......@@ -32,56 +40,64 @@ pub struct ServiceConfig {
/// Description
#[builder(default)]
description: Option<String>,
// todo - make optional - if None, then skip making the endpoint
// and skip making the service-endpoint discoverable.
/// Endpoint handler
#[educe(Debug(ignore))]
#[builder(default)]
stats_handler: Option<StatsHandler>,
}
impl ServiceConfigBuilder {
/// Create the [`Component`]'s service and store it in the registry.
pub async fn create(self) -> Result<Component> {
let version = "0.0.1".to_string();
let (component, description) = self.build_internal()?.dissolve();
let (component, description, stat_handler) = self.build_internal()?.dissolve();
let version = "0.0.1".to_string();
let service_name = component.service_name();
log::debug!("component: {component}; creating, service_name: {service_name}");
let description = description.unwrap_or(format!(
"Triton Component {} in {}",
"{PROJECT_NAME} component {} in namespace {}",
component.name, component.namespace
));
let mut guard = component.drt.component_registry.services.lock().await;
let stats_handler_registry: Arc<Mutex<HashMap<String, EndpointStatsHandler>>> =
Arc::new(Mutex::new(HashMap::new()));
let stats_handler_registry_clone = stats_handler_registry.clone();
let mut guard = component.drt.component_registry.inner.lock().await;
if guard.contains_key(&component.etcd_path()) {
if guard.services.contains_key(&service_name) {
return Err(anyhow::anyhow!("Service already exists"));
}
// create service on the secondary runtime
let secondary = component.drt.runtime.secondary();
let builder = component.drt.nats_client.client().service_builder();
let service = secondary
.spawn(async move {
// unwrap the stats handler
let builder = match stat_handler {
Some(handler) => builder.stats_handler(handler),
None => builder,
};
tracing::debug!("Starting service: {}", service_name);
builder
let service = builder
.description(description)
.start(service_name.to_string(), version)
.await
.stats_handler(move |name, stats| {
log::trace!("stats_handler: {name}, {stats:?}");
let mut guard = stats_handler_registry.lock().unwrap();
match guard.get_mut(&name) {
Some(handler) => handler(stats),
None => serde_json::Value::Null,
}
})
.await?
.start(service_name.clone(), version)
.await
.map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
guard.insert(component.etcd_path(), service);
// new copy of service_name as the previous one is moved into the task above
let service_name = component.service_name();
// insert the service into the registry
guard.services.insert(service_name.clone(), service);
// insert the stats handler into the registry
guard
.stats_handlers
.insert(service_name, stats_handler_registry_clone);
// drop the guard to unlock the mutex
drop(guard);
Ok(component)
......@@ -93,24 +109,3 @@ impl ServiceConfigBuilder {
Self::default().component(component)
}
}
// // Wrap the optional user callback method in a closure that appends the lease_id to the response
// fn wrap_callback(
// callback: Option<Box<dyn FnMut(String, Stats) -> Value + Send + Sync>>,
// lease_id: i64,
// ) -> Box<dyn FnMut(String, Stats) -> Value + Send + Sync> {
// let callback = Arc::new(Mutex::new(callback)); // Wrap in Arc<Mutex> for shared access
// Box::new(move |subject: String, stats: Stats| -> Value {
// let mut callback_lock = callback.lock().unwrap();
// if let Some(cb) = callback_lock.as_mut() {
// let mut result = cb(subject, stats); // Call the user-defined callback
// if let Some(obj) = result.as_object_mut() {
// obj.insert("lease_id".to_string(), json!(lease_id)); // Append lease_id
// }
// result
// } else {
// json!({ "error": "callback not set", "lease_id": lease_id }) // Default response
// }
// })
// }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment