Commit 09656f6c authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

feat: Add estimated kv cache hit metric events (#30)

parent a720fa12
......@@ -53,6 +53,8 @@ tensorrtllm_checkpoints/
tensorrtllm_engines/
api_server_models/
server/
# Replay/Snapshot test artifacts
*.new
**/*backups*
......@@ -75,4 +77,4 @@ __pycache__/
*$py.class
*.so
**/.devcontainer
\ No newline at end of file
**/.devcontainer
......@@ -739,6 +739,7 @@ dependencies = [
"clap",
"dynemo-llm",
"dynemo-runtime",
"futures",
"opentelemetry",
"opentelemetry-prometheus",
"prometheus",
......
......@@ -38,6 +38,7 @@ opentelemetry-prometheus = "0.13"
prometheus = "0.13"
rand = "0.8"
axum = "0.6"
futures = "0.3"
[dev-dependencies]
reqwest = { version = "0.11", features = ["blocking"] }
......@@ -16,7 +16,7 @@
//! Library functions for the count application.
use axum::{routing::get, Router};
use prometheus::register_gauge_vec;
use prometheus::{register_counter_vec, register_gauge_vec};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
......@@ -97,6 +97,18 @@ impl PrometheusMetricsServer {
/// Refresh the aggregated endpoint metrics from the latest processed
/// endpoint data by delegating to the inner `PrometheusMetrics` collection.
pub fn update(&mut self, config: &LLMWorkerLoadCapacityConfig, processed: &ProcessedEndpoints) {
    self.metrics.update(config, processed);
}
/// Update KV hit rate metrics
///
/// Forwards a single KV hit rate event (total ISL blocks and overlapping
/// blocks for `worker_id`) to the inner `PrometheusMetrics` collection,
/// which accumulates them into per-worker counters.
pub fn update_kv_hit_rate(
    &mut self,
    config: &LLMWorkerLoadCapacityConfig,
    worker_id: i64,
    isl_blocks: usize,
    overlap_blocks: usize,
) {
    self.metrics
        .update_kv_hit_rate(config, worker_id, isl_blocks, overlap_blocks);
}
}
/// Prometheus metrics collection
......@@ -107,6 +119,9 @@ pub struct PrometheusMetrics {
requests_total: prometheus::GaugeVec,
load_avg: prometheus::GaugeVec,
load_std: prometheus::GaugeVec,
// KV hit rate metrics
kv_hit_rate_isl_blocks: prometheus::CounterVec,
kv_hit_rate_overlap_blocks: prometheus::CounterVec,
}
impl PrometheusMetrics {
......@@ -143,6 +158,19 @@ impl PrometheusMetrics {
"Load standard deviation across workers",
&["component", "endpoint"]
)?,
// TODO: The cumulative isl/overlap metrics are monotonically increasing
// and may overflow at some point, we may want to periodically reset them.
// KV hit rate metrics
kv_hit_rate_isl_blocks: register_counter_vec!(
"llm_kv_hit_rate_isl_blocks",
"Cumulative count of ISL blocks in KV hit rate events",
&["component", "endpoint", "worker_id"]
)?,
kv_hit_rate_overlap_blocks: register_counter_vec!(
"llm_kv_hit_rate_overlap_blocks",
"Cumulative count of overlapping blocks in KV hit rate events",
&["component", "endpoint", "worker_id"]
)?,
})
}
......@@ -159,6 +187,19 @@ impl PrometheusMetrics {
.set(value);
}
/// Helper that increments a counter labeled by component, endpoint, and
/// worker id (3 labels) by `value`.
fn increment_worker_counter(
    &self,
    counter: &prometheus::CounterVec,
    config: &LLMWorkerLoadCapacityConfig,
    worker_id: &str,
    value: f64,
) {
    // Resolve (or lazily create) the labeled child counter, then bump it.
    let labeled = counter.with_label_values(&[
        &config.component_name,
        &config.endpoint_name,
        worker_id,
    ]);
    labeled.inc_by(value);
}
/// Helper method to set a gauge with component/endpoint labels only (2 labels)
fn set_endpoint_gauge(
&self,
......@@ -208,6 +249,61 @@ impl PrometheusMetrics {
self.set_endpoint_gauge(&self.load_avg, config, processed.load_avg);
self.set_endpoint_gauge(&self.load_std, config, processed.load_std);
}
/// Update KV hit rate metrics
///
/// Accumulates one KV hit rate event into the per-worker cumulative
/// counters and logs the resulting cumulative hit rate percentage.
pub fn update_kv_hit_rate(
    &self,
    config: &LLMWorkerLoadCapacityConfig,
    worker_id: i64,
    isl_blocks: usize,
    overlap_blocks: usize,
) {
    let worker_id_str = worker_id.to_string();

    // Bump both cumulative counters for this worker.
    self.increment_worker_counter(
        &self.kv_hit_rate_isl_blocks,
        config,
        &worker_id_str,
        isl_blocks as f64,
    );
    self.increment_worker_counter(
        &self.kv_hit_rate_overlap_blocks,
        config,
        &worker_id_str,
        overlap_blocks as f64,
    );

    // TODO: The cumulative hit rate percentage can probably be computed by consumers
    // of Prometheus metrics like Grafana instead, but we'll compute it here for now
    // for convenient debugging/logging.

    // Both counters share the same label set, so build it once.
    let labels: [&str; 3] = [
        &config.component_name,
        &config.endpoint_name,
        &worker_id_str,
    ];
    let cumulative_isl = self.kv_hit_rate_isl_blocks.with_label_values(&labels).get();
    let cumulative_overlap = self
        .kv_hit_rate_overlap_blocks
        .with_label_values(&labels)
        .get();

    // Avoid a divide-by-zero before any ISL blocks have been recorded.
    if cumulative_isl > 0.0 {
        let cumulative_hit_rate = (cumulative_overlap / cumulative_isl) * 100.0;
        tracing::info!(
            "Estimated Cumulative KV hit rate: {cumulative_hit_rate:.2}% (Overlap: {cumulative_overlap} / ISL: {cumulative_isl})"
        );
    }
}
}
/// Collect endpoints from a component
......
......@@ -22,14 +22,20 @@
//! - These metrics will be scraped by the LLM NATS Service API's stats request
//! - Request Slots: [Active, Total]
//! - KV Cache Blocks: [Active, Total]
//! - KV Hit Rate:
//! - These metrics will be collected from KV hit rate events published by the KV router
//! - ISL Blocks: Cumulative count of total blocks in all KV hit rate events
//! - Overlap Blocks: Cumulative count of blocks that were already in the KV cache
use clap::Parser;
use dynemo_llm::kv_router::scheduler::KVHitRateEvent;
use dynemo_runtime::{
error, logging,
traits::events::EventPublisher,
traits::events::{EventPublisher, EventSubscriber},
utils::{Duration, Instant},
DistributedRuntime, ErrorContext, Result, Runtime, Worker,
};
use futures::stream::StreamExt;
use std::sync::Arc;
// Import from our library
use count::{
......@@ -111,8 +117,65 @@ async fn app(runtime: Runtime) -> Result<()> {
// TODO: Make metrics host/port configurable
// Initialize Prometheus metrics and start server
let mut metrics_server = PrometheusMetricsServer::new()?;
metrics_server.start(9091);
let metrics_server = PrometheusMetricsServer::new()?;
// Metrics will be updated concurrently, so protect it with a mutex:
// - Main loop: Collect and process ForwardPassMetrics at an interval from endpoint stats handlers
// - Subscription task: Collect and process KVHitRateEvent metrics from the KV router as they are published
let metrics_server = Arc::new(tokio::sync::Mutex::new(metrics_server));
metrics_server.lock().await.start(9091);
// Subscribe to KV hit rate events
let kv_hit_rate_subject = "kv-hit-rate";
tracing::info!("Subscribing to KV hit rate events on subject: {kv_hit_rate_subject}");
// Clone the metrics server and config for the subscription task
let metrics_server_clone = metrics_server.clone();
let config_clone = config.clone();
// Clone the namespace for the subscription task
let namespace_clone = namespace.clone();
// Spawn a task to handle KV hit rate events
tokio::spawn(async move {
match namespace_clone.subscribe(kv_hit_rate_subject).await {
Ok(mut subscriber) => {
tracing::info!("Successfully subscribed to KV hit rate events");
while let Some(msg) = subscriber.next().await {
match serde_json::from_slice::<KVHitRateEvent>(&msg.payload) {
Ok(event) => {
// TODO: Lower to debug
let cache_hit_pct =
(event.overlap_blocks as f64 / event.isl_blocks as f64) * 100.0;
tracing::info!(
"Received KV hit rate event: worker_id={}, isl_blocks={}, overlap_blocks={}, cache_hit_pct={:.2}%",
event.worker_id,
event.isl_blocks,
event.overlap_blocks,
cache_hit_pct
);
// Update metrics with the event data
let mut metrics = metrics_server_clone.lock().await;
metrics.update_kv_hit_rate(
&config_clone,
event.worker_id,
event.isl_blocks,
event.overlap_blocks,
);
}
Err(e) => {
tracing::warn!("Failed to deserialize KV hit rate event: {:?}", e);
}
}
}
tracing::warn!("KV hit rate event subscription stream ended");
}
Err(e) => {
tracing::error!("Failed to subscribe to KV hit rate events: {:?}", e);
}
}
});
loop {
let next = Instant::now() + Duration::from_secs(args.poll_interval);
......@@ -123,12 +186,14 @@ async fn app(runtime: Runtime) -> Result<()> {
collect_endpoints(&target_component, &service_subject, scrape_timeout).await?;
let metrics = extract_metrics(&endpoints);
let processed = postprocess_metrics(&metrics, &endpoints);
tracing::info!("Aggregated metrics: {processed:?}");
tracing::debug!("Aggregated metrics: {processed:?}");
// Update Prometheus metrics
metrics_server.update(&config, &processed);
metrics_server.lock().await.update(&config, &processed);
// TODO: Who needs to consume these events?
// TODO: Enable KV Routers to subscribe to metrics events published here
// for a single view of the aggregated metrics, as opposed to the current
// approach where each KV Router computes and publishes its own metrics.
// Publish metrics event
namespace.publish(&event_name, &processed).await?;
......
......@@ -16,17 +16,17 @@
]
},
"copyright": [
"SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.",
"SPDX-License-Identifier: Apache-2.0",
"Licensed under the Apache License, Version 2.0 (the \"License\");",
"you may not use this file except in compliance with the License.",
"You may obtain a copy of the License at",
"http://www.apache.org/licenses/LICENSE-2.0",
"Unless required by applicable law or agreed to in writing, software",
"distributed under the License is distributed on an \"AS IS\" BASIS,",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
"See the License for the specific language governing permissions and",
"limitations under the License."
"SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.",
"SPDX-License-Identifier: Apache-2.0",
"Licensed under the Apache License, Version 2.0 (the \"License\");",
"you may not use this file except in compliance with the License.",
"You may obtain a copy of the License at",
"http://www.apache.org/licenses/LICENSE-2.0",
"Unless required by applicable law or agreed to in writing, software",
"distributed under the License is distributed on an \"AS IS\" BASIS,",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
"See the License for the specific language governing permissions and",
"limitations under the License."
],
"editable": true,
"fiscalYearStartMonth": 0,
......@@ -261,7 +261,7 @@
},
"gridPos": {
"h": 8,
"w": 6,
"w": 4,
"x": 0,
"y": 8
},
......@@ -329,8 +329,8 @@
},
"gridPos": {
"h": 8,
"w": 6,
"x": 6,
"w": 4,
"x": 4,
"y": 8
},
"id": 4,
......@@ -363,6 +363,74 @@
}
]
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "yellow",
"value": 50
},
{
"color": "red",
"value": 80
}
]
},
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 4,
"x": 8,
"y": 8
},
"id": 7,
"options": {
"orientation": "auto",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showThresholdLabels": false,
"showThresholdMarkers": true
},
"pluginVersion": "10.0.0",
"title": "Cumulative KV Cache Hit Rate",
"type": "gauge",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "100 * sum(llm_kv_hit_rate_overlap_blocks{component=\"$component\", endpoint=\"$endpoint\"}) / sum(llm_kv_hit_rate_isl_blocks{component=\"$component\", endpoint=\"$endpoint\"})",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
]
},
{
"datasource": {
"type": "prometheus",
......@@ -467,6 +535,190 @@
}
]
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 20,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
}
]
},
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 16
},
"id": 8,
"options": {
"legend": {
"calcs": [
"mean",
"max"
],
"displayMode": "table",
"placement": "right",
"showLegend": true
},
"tooltip": {
"mode": "multi",
"sort": "none"
}
},
"title": "KV Cache Hit Rate by Worker",
"type": "timeseries",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "100 * llm_kv_hit_rate_overlap_blocks{component=\"$component\", endpoint=\"$endpoint\"} / llm_kv_hit_rate_isl_blocks{component=\"$component\", endpoint=\"$endpoint\"}",
"legendFormat": "Worker {{worker_id}}",
"range": true,
"refId": "A"
}
]
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 20,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"lineInterpolation": "smooth",
"lineWidth": 2,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
}
]
},
"unit": "percent"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 16
},
"id": 9,
"options": {
"legend": {
"calcs": [
"mean",
"max"
],
"displayMode": "table",
"placement": "right",
"showLegend": true
},
"tooltip": {
"mode": "multi",
"sort": "none"
}
},
"title": "Cumulative KV Cache Hit Rate",
"type": "timeseries",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"editorMode": "code",
"expr": "100 * sum(llm_kv_hit_rate_overlap_blocks{component=\"$component\", endpoint=\"$endpoint\"}) / sum(llm_kv_hit_rate_isl_blocks{component=\"$component\", endpoint=\"$endpoint\"})",
"legendFormat": "Overall Hit Rate",
"range": true,
"refId": "A"
}
]
},
{
"datasource": {
"type": "prometheus",
......@@ -525,7 +777,7 @@
"h": 8,
"w": 24,
"x": 0,
"y": 16
"y": 24
},
"id": 6,
"options": {
......@@ -647,4 +899,4 @@
"uid": "llm-worker-metrics",
"version": 1,
"weekStart": ""
}
}
\ No newline at end of file
......@@ -33,7 +33,7 @@ ENDPOINT_NAME=${4:-"dynemo.process.chat/completions"}
VALID_STRATEGIES=("prefix")
SESSION_NAME="v"
WORKDIR="/workspace/examples/python_rs/llm/vllm"
INIT_CMD="source /opt/dynemo/venv/bin/activate && cd $WORKDIR"
INIT_CMD="cd $WORKDIR"
if [[ ! " ${VALID_STRATEGIES[@]} " =~ " ${ROUTING_STRATEGY} " ]]; then
echo "Error: Invalid routing strategy. Must be one of: ${VALID_STRATEGIES[*]}"
......
......@@ -14,7 +14,7 @@
// limitations under the License.
use anyhow::Result;
use dynemo_runtime::{component::Component, DistributedRuntime};
use dynemo_runtime::{component::Component, component::Namespace, DistributedRuntime};
use futures::stream::StreamExt;
use std::{sync::Arc, time::Duration};
use tokio_util::sync::CancellationToken;
......@@ -57,15 +57,19 @@ impl KvRouter {
let nats_client = runtime.nats_client();
let service_name = backend.service_name();
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
let namespace = runtime.namespace(backend.namespace())?;
tracing::info!("Component Namespace {}", backend.namespace());
tracing::info!("Component Service Name {}", service_name);
tracing::info!("KV Subject {}", kv_subject);
Self::new(nats_client, service_name, kv_subject).await
Self::new(nats_client, service_name, kv_subject, namespace).await
}
pub async fn new(
nats_client: dynemo_runtime::transports::nats::Client,
service_name: String,
kv_subject: String,
namespace: Namespace,
) -> Result<Arc<Self>> {
let cancellation_token = CancellationToken::new();
let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128);
......@@ -78,7 +82,7 @@ impl KvRouter {
));
let indexer = KvIndexer::new(cancellation_token.clone());
let scheduler = KvScheduler::start(ep_rx).await?;
let scheduler = KvScheduler::start(ep_rx, namespace).await?;
tracing::debug!("subscribing to kv events: {}", kv_subject);
let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?;
......
......@@ -13,6 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use dynemo_runtime::component::Namespace;
use dynemo_runtime::traits::events::EventPublisher;
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::cmp::min;
......@@ -21,7 +23,13 @@ use crate::kv_router::indexer::OverlapScores;
pub use crate::kv_router::protocols::{ForwardPassMetrics, KV_BLOCK_SIZE};
use crate::kv_router::scoring::ProcessedEndpoints;
// Event payload published by the KV router scheduler describing the estimated
// KV cache hit for a single scheduled request. Consumed by metrics collectors
// subscribed to the "kv-hit-rate" subject.
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
    /// Worker the request was routed to.
    pub worker_id: i64,
    /// Total input-sequence-length blocks for the request.
    pub isl_blocks: usize,
    /// Blocks estimated to already be in this worker's KV cache.
    pub overlap_blocks: usize,
}
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
#[error("no endpoints aviailable to route work")]
......@@ -93,6 +101,7 @@ pub struct KvScheduler {
impl KvScheduler {
pub async fn start(
endpoints_rx: tokio::sync::mpsc::Receiver<ProcessedEndpoints>,
ns: Namespace,
) -> Result<Self, KvSchedulerError> {
let mut endpoints_rx = endpoints_rx;
......@@ -104,6 +113,20 @@ impl KvScheduler {
}
};
// Channel to asynchronously publish metric events on
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
// Publisher task
tokio::spawn(async move {
let mut event_rx = event_rx;
let subject = "kv-hit-rate";
while let Some(event) = event_rx.recv().await {
if let Err(e) = ns.publish(subject, &event).await {
tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
}
}
});
// Channel to accept new scheduling requests
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(16);
tracing::debug!("scheduler starting");
......@@ -146,7 +169,7 @@ impl KvScheduler {
};
tracing::debug!("selected");
loop {
match select_worker(endpoints.borrow_mut(), &request) {
match select_worker(endpoints.borrow_mut(), &request, &event_tx) {
Ok(worker_id) => {
request.respond(worker_id);
continue 'outer;
......@@ -175,7 +198,6 @@ impl KvScheduler {
Ok(KvScheduler { request_tx })
}
#[allow(dead_code)]
pub async fn schedule(
&self,
overlap: OverlapScores,
......@@ -205,6 +227,7 @@ impl KvScheduler {
pub fn select_worker(
workers: &mut ProcessedEndpoints,
request: &SchedulingRequest,
event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
) -> Result<i64, KvSchedulerError> {
// balance mode prioritizes balancing load across workers
let balance_threshold: f64 = 0.1;
......@@ -268,6 +291,23 @@ pub fn select_worker(
workers.endpoints[best_index].data.request_active_slots += 1;
workers.endpoints[best_index].data.kv_active_blocks += total_blocks as u64;
// Optimization - pass this to a channel for emitting events, async task, etc. to avoid blocking the scheduler
let best_worker_id = workers.endpoints[best_index].worker_id();
let isl_blocks = request.isl_tokens / KV_BLOCK_SIZE;
let overlap_blocks = request
.overlap
.scores
.get(&best_worker_id)
.copied()
.unwrap_or(0);
if let Err(e) = event_tx.send(KVHitRateEvent {
worker_id: best_worker_id,
isl_blocks,
overlap_blocks: overlap_blocks as usize,
}) {
tracing::warn!("Failed to send KV hit rate event: {:?}", e);
}
}
match best_index {
......
......@@ -154,6 +154,11 @@ impl Component {
format!("{}/{}", self.namespace, self.name)
}
/// Returns a reference to the namespace string of this component.
///
/// This is the same namespace used as the prefix in [`Component::path`]-style
/// subjects ("{namespace}/{name}").
pub fn namespace(&self) -> &str {
    &self.namespace
}
pub fn drt(&self) -> &DistributedRuntime {
&self.drt
}
......
......@@ -14,10 +14,12 @@
// limitations under the License.
use async_trait::async_trait;
use futures::stream::StreamExt;
use futures::{Stream, TryStreamExt};
use super::*;
use crate::traits::events::EventPublisher;
use crate::traits::events::{EventPublisher, EventSubscriber};
#[async_trait]
impl EventPublisher for Namespace {
......@@ -49,6 +51,32 @@ impl EventPublisher for Namespace {
}
}
#[async_trait]
impl EventSubscriber for Namespace {
    /// Subscribe to raw NATS messages published under this namespace's
    /// subject, suffixed with `event_name`.
    async fn subscribe(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<async_nats::Subscriber> {
        // Event subjects are "<namespace base subject>.<event name>",
        // mirroring how EventPublisher builds publish subjects.
        let subject = format!("{}.{}", self.subject(), event_name.as_ref());
        Ok(self.drt().nats_client().client().subscribe(subject).await?)
    }

    /// Subscribe and JSON-deserialize each message payload into `T`,
    /// yielding `Err` items for payloads that fail to deserialize.
    async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<impl Stream<Item = Result<T>> + Send> {
        let subscriber = self.subscribe(event_name).await?;

        // Transform the subscriber into a stream of deserialized events
        let stream = subscriber.map(move |msg| {
            serde_json::from_slice::<T>(&msg.payload)
                .map_err(|e| anyhow::anyhow!("Failed to deserialize event: {}", e))
        });

        Ok(stream)
    }
}
#[cfg(feature = "integration")]
#[cfg(test)]
mod tests {
......@@ -64,4 +92,27 @@ mod tests {
ns.publish("test", &"test".to_string()).await.unwrap();
rt.shutdown();
}
#[tokio::test]
async fn test_subscribe() {
    let rt = Runtime::from_current().unwrap();
    let dtr = DistributedRuntime::from_settings(rt.clone()).await.unwrap();
    let ns = dtr.namespace("test".to_string()).unwrap();

    // Create a subscriber. The binding must be `mut`: `StreamExt::next`
    // takes `&mut self` (the production subscriber loop does the same).
    let mut subscriber = ns.subscribe("test").await.unwrap();

    // Publish a message
    ns.publish("test", &"test_message".to_string())
        .await
        .unwrap();

    // Receive the message; fail the test (rather than silently passing)
    // if the subscription stream ends without delivering one.
    let msg = subscriber
        .next()
        .await
        .expect("expected a message on the test subject");
    let received = String::from_utf8(msg.payload.to_vec()).unwrap();
    // publish() JSON-encodes the payload, so the string arrives quoted.
    assert_eq!(received, "\"test_message\"");

    rt.shutdown();
}
}
......@@ -56,3 +56,24 @@ pub trait EventPublisher {
// fn publisher(&self, event_name: impl AsRef<str>) -> Result<Publisher>;
// fn publisher_bytes(&self, event_name: impl AsRef<str>) -> &PublisherBytes;
}
/// A trait for subscribing to events in the event plane.
///
/// This trait provides methods to subscribe to events published on specific subjects.
/// It is the receiving-side counterpart to `EventPublisher`.
#[async_trait]
pub trait EventSubscriber {
    /// Subscribe to events with the given event name.
    /// The `event_name` will be `.` concatenated with the base subject provided by the implementation.
    /// Returns a subscriber that can be used to receive events.
    async fn subscribe(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<async_nats::Subscriber>;

    /// Subscribe to events with the given event name and deserialize them to the specified type.
    /// This is a convenience method that combines subscribe and deserialization.
    /// Each stream item is a `Result<T>`; deserialization failures surface as `Err` items.
    async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
    ) -> Result<impl futures::Stream<Item = Result<T>> + Send>;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment