"pcdet/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "c5a759f3d4c98f65e6cc07625e2f9005ef15fd28"
Commit 6e09681e authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: expose Python binding for KVEventPublisher. Use event pub/sub trait for KV events (#169)

parent df51a622
...@@ -148,7 +148,7 @@ fn dynamo_create_kv_publisher( ...@@ -148,7 +148,7 @@ fn dynamo_create_kv_publisher(
{ {
Ok(drt) => { Ok(drt) => {
let backend = drt.namespace(namespace)?.component(component)?; let backend = drt.namespace(namespace)?.component(component)?;
KvEventPublisher::new(drt.clone(), backend, worker_id, kv_block_size) KvEventPublisher::new(backend, worker_id, kv_block_size)
} }
Err(e) => Err(e), Err(e) => Err(e),
} }
......
...@@ -77,6 +77,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -77,6 +77,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::EndpointKvMetrics>()?; m.add_class::<llm::kv::EndpointKvMetrics>()?;
m.add_class::<llm::kv::AggregatedMetrics>()?; m.add_class::<llm::kv::AggregatedMetrics>()?;
m.add_class::<llm::kv::KvMetricsAggregator>()?; m.add_class::<llm::kv::KvMetricsAggregator>()?;
m.add_class::<llm::kv::KvEventPublisher>()?;
m.add_class::<http::HttpService>()?; m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpError>()?; m.add_class::<http::HttpError>()?;
m.add_class::<http::HttpAsyncEngine>()?; m.add_class::<http::HttpAsyncEngine>()?;
...@@ -227,10 +228,6 @@ impl Component { ...@@ -227,10 +228,6 @@ impl Component {
Ok(()) Ok(())
}) })
} }
fn event_subject(&self, name: String) -> String {
self.inner.event_subject(name)
}
} }
#[pymethods] #[pymethods]
......
...@@ -17,8 +17,11 @@ use std::collections::HashMap; ...@@ -17,8 +17,11 @@ use std::collections::HashMap;
use super::*; use super::*;
use llm_rs::kv_router::indexer::KvIndexerInterface; use llm_rs::kv_router::indexer::KvIndexerInterface;
use rs::traits::events::EventSubscriber;
use tracing; use tracing;
use llm_rs::kv_router::{indexer::compute_block_hash_for_seq, protocols::*};
#[pyclass] #[pyclass]
pub(crate) struct KvRouter { pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvRouter>, inner: Arc<llm_rs::kv_router::KvRouter>,
...@@ -119,6 +122,114 @@ impl KvMetricsPublisher { ...@@ -119,6 +122,114 @@ impl KvMetricsPublisher {
} }
} }
#[pyclass]
pub(crate) struct KvEventPublisher {
inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
warning_count: u32,
}
#[pymethods]
impl KvEventPublisher {
#[new]
fn new(component: Component, worker_id: i64, kv_block_size: usize) -> PyResult<Self> {
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner.clone(),
worker_id,
kv_block_size,
)
.map_err(to_pyerr)?;
Ok(Self {
inner: inner.into(),
warning_count: 0,
})
}
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None))]
fn publish_stored(
&mut self,
_py: Python,
event_id: u64,
token_ids: Vec<u32>,
num_block_tokens: Vec<u64>,
block_hashes: Vec<u64>,
lora_id: u64,
parent_hash: Option<u64>,
) -> PyResult<()> {
let event = KvCacheEvent {
event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_hash.map(ExternalSequenceBlockHash),
blocks: self.create_stored_blocks(
&token_ids,
&num_block_tokens,
&block_hashes,
lora_id,
),
}),
};
self.inner.publish(event).map_err(to_pyerr)
}
fn publish_removed(&self, _py: Python, event_id: u64, block_hashes: Vec<u64>) -> PyResult<()> {
let block_hashes: Vec<ExternalSequenceBlockHash> = block_hashes
.iter()
.map(|&v| ExternalSequenceBlockHash(v))
.collect();
let event = KvCacheEvent {
event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
};
self.inner.publish(event).map_err(to_pyerr)
}
}
impl KvEventPublisher {
fn create_stored_block_from_parts(
&self,
block_hash: u64,
token_ids: &[u32],
_lora_id: u64,
) -> KvCacheStoredBlockData {
let tokens_hash = compute_block_hash_for_seq(token_ids, self.inner.kv_block_size())[0];
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(block_hash),
tokens_hash,
}
}
fn create_stored_blocks(
&mut self,
token_ids: &[u32],
num_block_tokens: &[u64],
block_hashes: &[u64],
lora_id: u64,
) -> Vec<KvCacheStoredBlockData> {
let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
let mut token_offset: usize = 0;
for (num_tokens_it, block_hash_it) in num_block_tokens.iter().zip(block_hashes.iter()) {
if (self.warning_count < 3) && (*num_tokens_it != self.inner.kv_block_size() as u64) {
tracing::warn!(
"Block not published. Block size must be {} tokens to be published. Block size is: {}",
self.inner.kv_block_size(),
*num_tokens_it
);
self.warning_count += 1;
break;
}
let tokens = &token_ids[token_offset..(token_offset + *num_tokens_it as usize)];
blocks.push(self.create_stored_block_from_parts(*block_hash_it, tokens, lora_id));
token_offset += *num_tokens_it as usize;
}
blocks
}
}
#[pyclass] #[pyclass]
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct OverlapScores { pub(crate) struct OverlapScores {
...@@ -149,21 +260,17 @@ impl KvIndexer { ...@@ -149,21 +260,17 @@ impl KvIndexer {
fn new(component: Component, kv_block_size: usize) -> PyResult<Self> { fn new(component: Component, kv_block_size: usize) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime(); let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async { runtime.block_on(async {
let kv_subject = component
.inner
.event_subject(llm_rs::kv_router::KV_EVENT_SUBJECT);
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> = let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
llm_rs::kv_router::indexer::KvIndexer::new( llm_rs::kv_router::indexer::KvIndexer::new(
component.inner.drt().runtime().child_token(), component.inner.drt().runtime().child_token(),
kv_block_size, kv_block_size,
) )
.into(); .into();
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
// error checking below will be different.
let mut kv_events_rx = component let mut kv_events_rx = component
.inner .inner
.drt() .subscribe(llm_rs::kv_router::KV_EVENT_SUBJECT)
.nats_client()
.client()
.subscribe(kv_subject)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
let kv_events_tx = inner.event_sender(); let kv_events_tx = inner.event_sender();
......
...@@ -102,11 +102,6 @@ class Component: ...@@ -102,11 +102,6 @@ class Component:
""" """
... ...
def event_subject(self, name: str) -> str:
"""
Create an event subject
"""
...
class Endpoint: class Endpoint:
""" """
...@@ -354,6 +349,30 @@ class KvMetricsAggregator: ...@@ -354,6 +349,30 @@ class KvMetricsAggregator:
""" """
... ...
class KvEventPublisher:
"""
A KV event publisher will publish KV events corresponding to the component.
"""
...
def __init__(self, component: Component, worker_id: int, kv_block_size: int) -> None:
"""
Create a `KvEventPublisher` object
"""
def publish_stored(self, event_id, int, token_ids: List[int], num_block_tokens: List[int], block_hashes: List[int], lora_id: int, parent_hash: Optional[int] = None) -> None:
"""
Publish a KV stored event.
"""
...
def publish_removed(self, event_id, int, block_hashes: List[int]) -> None:
"""
Publish a KV removed event.
"""
...
class HttpService: class HttpService:
""" """
A HTTP service for dynamo applications. A HTTP service for dynamo applications.
......
...@@ -18,6 +18,7 @@ from dynamo._core import DisaggregatedRouter as DisaggregatedRouter ...@@ -18,6 +18,7 @@ from dynamo._core import DisaggregatedRouter as DisaggregatedRouter
from dynamo._core import HttpAsyncEngine as HttpAsyncEngine from dynamo._core import HttpAsyncEngine as HttpAsyncEngine
from dynamo._core import HttpError as HttpError from dynamo._core import HttpError as HttpError
from dynamo._core import HttpService as HttpService from dynamo._core import HttpService as HttpService
from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvIndexer as KvIndexer from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvMetricsPublisher as KvMetricsPublisher from dynamo._core import KvMetricsPublisher as KvMetricsPublisher
......
...@@ -24,8 +24,8 @@ from pydantic import BaseModel, ValidationError ...@@ -24,8 +24,8 @@ from pydantic import BaseModel, ValidationError
# import * causes "unable to detect undefined names" # import * causes "unable to detect undefined names"
from dynamo._core import Backend as Backend from dynamo._core import Backend as Backend
from dynamo._core import Client as Client from dynamo._core import Client as Client
from dynamo._core import Component as Component
from dynamo._core import DistributedRuntime as DistributedRuntime from dynamo._core import DistributedRuntime as DistributedRuntime
from dynamo._core import KvRouter as KvRouter
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
......
...@@ -24,8 +24,13 @@ from typing import List ...@@ -24,8 +24,13 @@ from typing import List
import pytest import pytest
from dynamo.llm import KvIndexer, KvMetricsAggregator, KvMetricsPublisher from dynamo.llm import (
from dynamo.runtime import DistributedRuntime KvEventPublisher,
KvIndexer,
KvMetricsAggregator,
KvMetricsPublisher,
)
from dynamo.runtime import Component, DistributedRuntime
pytestmark = pytest.mark.pre_merge pytestmark = pytest.mark.pre_merge
...@@ -64,14 +69,14 @@ async def test_event_handler(distributed_runtime): ...@@ -64,14 +69,14 @@ async def test_event_handler(distributed_runtime):
kv_block_size = 32 kv_block_size = 32
namespace = "kv_test" namespace = "kv_test"
component = "event" component = "event"
kv_listener = distributed_runtime.namespace(namespace).component(component)
await kv_listener.create_service()
# publisher # publisher
worker_id = 233 worker_id = 233
event_publisher = EventPublisher(namespace, component, worker_id, kv_block_size) event_publisher = EventPublisher(kv_listener, worker_id, kv_block_size)
# indexer # indexer
kv_listener = distributed_runtime.namespace(namespace).component(component)
await kv_listener.create_service()
indexer = KvIndexer(kv_listener, kv_block_size) indexer = KvIndexer(kv_listener, kv_block_size)
test_token = [3] * kv_block_size test_token = [3] * kv_block_size
...@@ -93,16 +98,48 @@ async def test_event_handler(distributed_runtime): ...@@ -93,16 +98,48 @@ async def test_event_handler(distributed_runtime):
scores = await indexer.find_matches_for_request(test_token, lora_id) scores = await indexer.find_matches_for_request(test_token, lora_id)
assert not scores.scores assert not scores.scores
event_publisher.shutdown()
class EventPublisher:
def __init__(self, component: Component, worker_id: int, kv_block_size: int):
self.publisher = KvEventPublisher(component, worker_id, kv_block_size)
self.event_id_counter = 0
self.block_hashes: List[int] = []
def store_event(self, tokens, lora_id):
parent_hash = self.event_id_counter if self.event_id_counter > 0 else None
self.publisher.publish_stored(
self.event_id_counter, # event_id
tokens, # token_ids
[
len(tokens),
], # num_block_tokens
[
self.event_id_counter,
], # block_hashes
lora_id, # lora_id
parent_hash, # parent_hash
)
self.block_hashes.append(self.event_id_counter)
self.event_id_counter += 1
def remove_event(self):
self.publisher.publish_removed(
self.event_id_counter, # event_id
[
self.block_hashes[-1],
], # block_hashes
)
self.event_id_counter += 1
# [TODO] to be deprecated
# KV events # KV events
class DynamoResult: class DynamoResult:
OK = 0 OK = 0
ERR = 1 ERR = 1
class EventPublisher: class CtypesEventPublisher:
def __init__( def __init__(
self, namespace: str, component: str, worker_id: int, kv_block_size: int self, namespace: str, component: str, worker_id: int, kv_block_size: int
): ):
......
...@@ -29,14 +29,18 @@ pub mod scoring; ...@@ -29,14 +29,18 @@ pub mod scoring;
use crate::kv_router::{ use crate::kv_router::{
indexer::{KvIndexer, KvIndexerInterface, RouterEvent}, indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
metrics_aggregator::collect_endpoints, metrics_aggregator::collect_endpoints_task,
scheduler::KvScheduler, scheduler::KvScheduler,
scoring::ProcessedEndpoints, scoring::ProcessedEndpoints,
}; };
// this should be discovered from the backend use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
pub const KV_EVENT_SUBJECT: &str = "kv_events"; pub const KV_EVENT_SUBJECT: &str = "kv_events";
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate"; pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
pub struct KvRouter { pub struct KvRouter {
// properties of request plane // properties of request plane
...@@ -54,40 +58,27 @@ pub struct KvRouter { ...@@ -54,40 +58,27 @@ pub struct KvRouter {
impl KvRouter { impl KvRouter {
pub async fn from_runtime( pub async fn from_runtime(
runtime: DistributedRuntime, runtime: DistributedRuntime,
backend: Component, component: Component,
kv_block_size: usize, kv_block_size: usize,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let nats_client = runtime.nats_client(); let namespace = runtime.namespace(component.namespace().name())?;
let service_name = backend.service_name();
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT); tracing::info!("Component Namespace {}", component.namespace());
let namespace = runtime.namespace(backend.namespace())?; tracing::info!("Component Service Name {}", component.service_name());
tracing::info!("KV Subject {}.{}", component.subject(), KV_EVENT_SUBJECT);
tracing::info!("Component Namespace {}", backend.namespace()); Self::new(component, namespace, kv_block_size).await
tracing::info!("Component Service Name {}", service_name);
tracing::info!("KV Subject {}", kv_subject);
Self::new(
nats_client,
service_name,
kv_subject,
namespace,
kv_block_size,
)
.await
} }
pub async fn new( pub async fn new(
nats_client: dynamo_runtime::transports::nats::Client, component: Component,
service_name: String,
kv_subject: String,
namespace: Namespace, namespace: Namespace,
kv_block_size: usize, kv_block_size: usize,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new();
let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128); let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128);
tokio::spawn(collect_endpoints( tokio::spawn(collect_endpoints_task(
nats_client.clone(), component.clone(),
service_name.clone(),
ep_tx, ep_tx,
cancellation_token.clone(), cancellation_token.clone(),
)); ));
...@@ -95,8 +86,9 @@ impl KvRouter { ...@@ -95,8 +86,9 @@ impl KvRouter {
let indexer = KvIndexer::new(cancellation_token.clone(), kv_block_size); let indexer = KvIndexer::new(cancellation_token.clone(), kv_block_size);
let scheduler = KvScheduler::start(ep_rx, namespace, kv_block_size).await?; let scheduler = KvScheduler::start(ep_rx, namespace, kv_block_size).await?;
tracing::debug!("subscribing to kv events: {}", kv_subject); // [gluo TODO] try subscribe_with_type::<RouterEvent>,
let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?; // error checking below will be different.
let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
let kv_events_tx = indexer.event_sender(); let kv_events_tx = indexer.event_sender();
tokio::spawn(async move { tokio::spawn(async move {
...@@ -120,7 +112,7 @@ impl KvRouter { ...@@ -120,7 +112,7 @@ impl KvRouter {
}); });
Ok(Arc::new(Self { Ok(Arc::new(Self {
service_name, service_name: component.service_name(),
cancellation_token, cancellation_token,
scheduler, scheduler,
indexer, indexer,
......
...@@ -16,11 +16,12 @@ ...@@ -16,11 +16,12 @@
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
pub use crate::kv_router::protocols::ForwardPassMetrics; pub use crate::kv_router::protocols::ForwardPassMetrics;
use crate::kv_router::KV_METRICS_ENDPOINT;
use crate::kv_router::scheduler::{Endpoint, Service}; use crate::kv_router::scheduler::Endpoint;
use crate::kv_router::ProcessedEndpoints; use crate::kv_router::ProcessedEndpoints;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use std::time::Duration; use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
pub struct KvMetricsAggregator { pub struct KvMetricsAggregator {
...@@ -32,9 +33,8 @@ impl KvMetricsAggregator { ...@@ -32,9 +33,8 @@ impl KvMetricsAggregator {
pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self { pub async fn new(component: Component, cancellation_token: CancellationToken) -> Self {
let (ep_tx, mut ep_rx) = tokio::sync::mpsc::channel(128); let (ep_tx, mut ep_rx) = tokio::sync::mpsc::channel(128);
tokio::spawn(collect_endpoints( tokio::spawn(collect_endpoints_task(
component.drt().nats_client().clone(), component.clone(),
component.service_name(),
ep_tx, ep_tx,
cancellation_token.clone(), cancellation_token.clone(),
)); ));
...@@ -80,13 +80,41 @@ impl KvMetricsAggregator { ...@@ -80,13 +80,41 @@ impl KvMetricsAggregator {
} }
} }
/// [gluo TODO] 'collect_endpoints' is from component/metrics,
/// should consolidate these functions into generic metrics aggregator
/// functions and shared by KvMetricsAggregator and component/metrics.
/// Collect endpoints from a component
pub async fn collect_endpoints( pub async fn collect_endpoints(
nats_client: dynamo_runtime::transports::nats::Client, component: &Component,
service_name: String, subject: &str,
timeout: Duration,
) -> Result<Vec<EndpointInfo>> {
// Collect stats from each backend
let stream = component.scrape_stats(timeout).await?;
// Filter the stats by the service subject
let endpoints = stream
.into_endpoints()
.filter(|e| e.subject.starts_with(subject))
.collect::<Vec<_>>();
tracing::debug!("Endpoints: {endpoints:?}");
if endpoints.is_empty() {
tracing::warn!("No endpoints found matching subject {subject}");
}
Ok(endpoints)
}
pub async fn collect_endpoints_task(
component: Component,
ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>, ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>,
cancel: CancellationToken, cancel: CancellationToken,
) { ) {
let backoff_delay = Duration::from_millis(100); let backoff_delay = Duration::from_millis(100);
let scrape_timeout = Duration::from_millis(300);
let endpoint = component.endpoint(KV_METRICS_ENDPOINT);
let service_subject = endpoint.subject();
loop { loop {
tokio::select! { tokio::select! {
...@@ -95,48 +123,41 @@ pub async fn collect_endpoints( ...@@ -95,48 +123,41 @@ pub async fn collect_endpoints(
break; break;
} }
_ = tokio::time::sleep(backoff_delay) => { _ = tokio::time::sleep(backoff_delay) => {
tracing::trace!("collecting endpoints for service: {}", service_name); tracing::trace!("collecting endpoints for service: {}", service_subject);
let values = match nats_client let unfiltered_endpoints =
.get_endpoints(&service_name, Duration::from_millis(300)) match collect_endpoints(&component, &service_subject, scrape_timeout).await
.await {
{ Ok(v) => v,
Ok(v) => v,
Err(e) => {
tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
continue;
}
};
tracing::debug!("values: {:?}", values);
let services: Vec<Service> = values
.into_iter()
.filter(|v| !v.is_empty())
.filter_map(|v| match serde_json::from_slice::<Service>(&v) {
Ok(service) => Some(service),
Err(e) => { Err(e) => {
tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e); tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_subject, e);
None continue;
} }
}) };
.collect(); tracing::debug!("unfiltered endpoints: {:?}", unfiltered_endpoints);
tracing::debug!("services: {:?}", services);
let endpoints: Vec<Endpoint> = services let endpoints: Vec<Endpoint> = unfiltered_endpoints
.into_iter() .into_iter()
.flat_map(|s| s.endpoints)
.filter(|s| s.data.is_some()) .filter(|s| s.data.is_some())
.map(|s| Endpoint { .filter_map(|s|
name: s.name, match s.data.unwrap().decode::<ForwardPassMetrics>() {
subject: s.subject, Ok(data) => Some(Endpoint {
data: s.data.unwrap(), name: s.name,
}) subject: s.subject,
data,
}),
Err(e) => {
tracing::debug!("skip endpoint data that can't be parsed as ForwardPassMetrics: {:?}", e);
None
}
}
)
.collect(); .collect();
tracing::debug!("endpoints: {:?}", endpoints); tracing::debug!("endpoints: {:?}", endpoints);
tracing::trace!( tracing::trace!(
"found {} endpoints for service: {}", "found {} endpoints for service: {}",
endpoints.len(), endpoints.len(),
service_name service_subject
); );
let processed = ProcessedEndpoints::new(endpoints); let processed = ProcessedEndpoints::new(endpoints);
......
...@@ -13,8 +13,9 @@ ...@@ -13,8 +13,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::kv_router::{indexer::RouterEvent, protocols::*, KV_EVENT_SUBJECT}; use crate::kv_router::{indexer::RouterEvent, protocols::*, KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT};
use async_trait::async_trait; use async_trait::async_trait;
use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider};
use dynamo_runtime::{ use dynamo_runtime::{
component::Component, component::Component,
pipeline::{ pipeline::{
...@@ -22,7 +23,7 @@ use dynamo_runtime::{ ...@@ -22,7 +23,7 @@ use dynamo_runtime::{
SingleIn, SingleIn,
}, },
protocols::annotated::Annotated, protocols::annotated::Annotated,
DistributedRuntime, Error, Result, Error, Result,
}; };
use futures::stream; use futures::stream;
use std::sync::Arc; use std::sync::Arc;
...@@ -35,16 +36,11 @@ pub struct KvEventPublisher { ...@@ -35,16 +36,11 @@ pub struct KvEventPublisher {
} }
impl KvEventPublisher { impl KvEventPublisher {
pub fn new( pub fn new(component: Component, worker_id: i64, kv_block_size: usize) -> Result<Self> {
drt: DistributedRuntime,
backend: Component,
worker_id: i64,
kv_block_size: usize,
) -> Result<Self> {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>(); let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let p = KvEventPublisher { tx, kv_block_size }; let p = KvEventPublisher { tx, kv_block_size };
start_publish_task(drt, backend, worker_id, rx); start_publish_task(component, worker_id, rx);
Ok(p) Ok(p)
} }
...@@ -59,21 +55,18 @@ impl KvEventPublisher { ...@@ -59,21 +55,18 @@ impl KvEventPublisher {
} }
fn start_publish_task( fn start_publish_task(
drt: DistributedRuntime, component: Component,
backend: Component,
worker_id: i64, worker_id: i64,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>, mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
) { ) {
let client = drt.nats_client().client().clone(); let component_clone = component.clone();
let kv_subject = backend.event_subject(KV_EVENT_SUBJECT); log::info!("Publishing KV Events to subject: {}", KV_EVENT_SUBJECT);
log::info!("Publishing KV Events to subject: {}", kv_subject);
_ = drt.runtime().secondary().spawn(async move { _ = component.drt().runtime().secondary().spawn(async move {
while let Some(event) = rx.recv().await { while let Some(event) = rx.recv().await {
let router_event = RouterEvent::new(worker_id, event); let router_event = RouterEvent::new(worker_id, event);
let data = serde_json::to_string(&router_event).unwrap(); component_clone
client .publish(KV_EVENT_SUBJECT, &router_event)
.publish(kv_subject.to_string(), data.into())
.await .await
.unwrap(); .unwrap();
} }
...@@ -105,7 +98,7 @@ impl KvMetricsPublisher { ...@@ -105,7 +98,7 @@ impl KvMetricsPublisher {
let handler = Ingress::for_engine(handler)?; let handler = Ingress::for_engine(handler)?;
component component
.endpoint("load_metrics") .endpoint(KV_METRICS_ENDPOINT)
.endpoint_builder() .endpoint_builder()
.stats_handler(move |_| { .stats_handler(move |_| {
let metrics = metrics_rx.borrow_and_update().clone(); let metrics = metrics_rx.borrow_and_update().clone();
......
...@@ -43,13 +43,8 @@ pub enum KvSchedulerError { ...@@ -43,13 +43,8 @@ pub enum KvSchedulerError {
SubscriberShutdown, SubscriberShutdown,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] /// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
pub struct FlexibleEndpoint { /// is cleaned (not optional)
pub name: String,
pub subject: String,
pub data: Option<ForwardPassMetrics>,
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint { pub struct Endpoint {
pub name: String, pub name: String,
...@@ -72,24 +67,6 @@ impl Endpoint { ...@@ -72,24 +67,6 @@ impl Endpoint {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlexibleService {
pub name: String,
pub id: String,
pub version: String,
pub started: String,
pub endpoints: Vec<FlexibleEndpoint>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Service {
pub name: String,
pub id: String,
pub version: String,
pub started: String,
pub endpoints: Vec<FlexibleEndpoint>,
}
pub struct SchedulingRequest { pub struct SchedulingRequest {
isl_tokens: usize, isl_tokens: usize,
overlap: OverlapScores, overlap: OverlapScores,
......
...@@ -61,6 +61,8 @@ use std::{collections::HashMap, sync::Arc}; ...@@ -61,6 +61,8 @@ use std::{collections::HashMap, sync::Arc};
use validator::{Validate, ValidationError}; use validator::{Validate, ValidationError};
mod client; mod client;
#[allow(clippy::module_inception)]
mod component;
mod endpoint; mod endpoint;
mod namespace; mod namespace;
mod registry; mod registry;
...@@ -115,12 +117,12 @@ pub struct Component { ...@@ -115,12 +117,12 @@ pub struct Component {
// todo - restrict the namespace to a-z0-9-_A-Z // todo - restrict the namespace to a-z0-9-_A-Z
/// Namespace /// Namespace
#[builder(setter(into))] #[builder(setter(into))]
namespace: String, namespace: Namespace,
} }
impl std::fmt::Display for Component { impl std::fmt::Display for Component {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}.{}", self.namespace, self.name) write!(f, "{}.{}", self.namespace.name(), self.name)
} }
} }
...@@ -138,30 +140,21 @@ impl RuntimeProvider for Component { ...@@ -138,30 +140,21 @@ impl RuntimeProvider for Component {
impl Component { impl Component {
pub fn etcd_path(&self) -> String { pub fn etcd_path(&self) -> String {
format!("{}/components/{}", self.namespace, self.name) format!("{}/components/{}", self.namespace.name(), self.name)
} }
pub fn service_name(&self) -> String { pub fn service_name(&self) -> String {
Slug::from_string(format!("{}|{}", self.namespace, self.name)).to_string() Slug::from_string(format!("{}|{}", self.namespace.name(), self.name)).to_string()
}
// todo - move to EventPlane
pub fn event_subject(&self, name: impl AsRef<str>) -> String {
format!("{}.events.{}", self.service_name(), name.as_ref())
} }
pub fn path(&self) -> String { pub fn path(&self) -> String {
format!("{}/{}", self.namespace, self.name) format!("{}/{}", self.namespace.name(), self.name)
} }
pub fn namespace(&self) -> &str { pub fn namespace(&self) -> &Namespace {
&self.namespace &self.namespace
} }
pub fn drt(&self) -> &DistributedRuntime {
&self.drt
}
pub fn endpoint(&self, endpoint: impl Into<String>) -> Endpoint { pub fn endpoint(&self, endpoint: impl Into<String>) -> Endpoint {
Endpoint { Endpoint {
component: self.clone(), component: self.clone(),
...@@ -300,6 +293,12 @@ impl RuntimeProvider for Namespace { ...@@ -300,6 +293,12 @@ impl RuntimeProvider for Namespace {
} }
} }
impl std::fmt::Display for Namespace {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.name)
}
}
impl Namespace { impl Namespace {
pub(crate) fn new(runtime: DistributedRuntime, name: String) -> Result<Self> { pub(crate) fn new(runtime: DistributedRuntime, name: String) -> Result<Self> {
Ok(NamespaceBuilder::default() Ok(NamespaceBuilder::default()
...@@ -312,7 +311,7 @@ impl Namespace { ...@@ -312,7 +311,7 @@ impl Namespace {
pub fn component(&self, name: impl Into<String>) -> Result<Component> { pub fn component(&self, name: impl Into<String>) -> Result<Component> {
Ok(ComponentBuilder::from_runtime(self.runtime.clone()) Ok(ComponentBuilder::from_runtime(self.runtime.clone())
.name(name) .name(name)
.namespace(self.name.clone()) .namespace(self.clone())
.build()?) .build()?)
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use async_trait::async_trait;
use futures::stream::StreamExt;
use futures::{Stream, TryStreamExt};
use super::*;
use crate::traits::events::{EventPublisher, EventSubscriber};
#[async_trait]
impl EventPublisher for Component {
fn subject(&self) -> String {
format!("namespace.{}.component.{}", self.namespace.name, self.name)
}
async fn publish(
&self,
event_name: impl AsRef<str> + Send + Sync,
event: &(impl Serialize + Send + Sync),
) -> Result<()> {
let bytes = serde_json::to_vec(event)?;
self.publish_bytes(event_name, bytes).await
}
async fn publish_bytes(
&self,
event_name: impl AsRef<str> + Send + Sync,
bytes: Vec<u8>,
) -> Result<()> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref());
Ok(self
.drt()
.nats_client()
.client()
.publish(subject, bytes.into())
.await?)
}
}
#[async_trait]
impl EventSubscriber for Component {
async fn subscribe(
&self,
event_name: impl AsRef<str> + Send + Sync,
) -> Result<async_nats::Subscriber> {
let subject = format!("{}.{}", self.subject(), event_name.as_ref());
Ok(self.drt().nats_client().client().subscribe(subject).await?)
}
async fn subscribe_with_type<T: for<'de> Deserialize<'de> + Send + 'static>(
&self,
event_name: impl AsRef<str> + Send + Sync,
) -> Result<impl Stream<Item = Result<T>> + Send> {
let subscriber = self.subscribe(event_name).await?;
// Transform the subscriber into a stream of deserialized events
let stream = subscriber.map(move |msg| {
serde_json::from_slice::<T>(&msg.payload)
.with_context(|| format!("Failed to deserialize event payload: {:?}", msg.payload))
});
Ok(stream)
}
}
#[cfg(feature = "integration")]
#[cfg(test)]
mod tests {
    use super::*;
    // todo - make a distributed runtime fixture
    // todo - two options - fully mocked or integration test

    /// Publishing an event on a component subject should succeed against a
    /// live NATS server.
    #[tokio::test]
    async fn test_publish() {
        let rt = Runtime::from_current().unwrap();
        let dtr = DistributedRuntime::from_settings(rt.clone()).await.unwrap();
        let ns = dtr.namespace("test".to_string()).unwrap();
        let cp = ns.component("component".to_string()).unwrap();
        cp.publish("test", &"test".to_string()).await.unwrap();
        rt.shutdown();
    }

    /// A subscriber created before publishing must receive the message.
    #[tokio::test]
    async fn test_subscribe() {
        let rt = Runtime::from_current().unwrap();
        let dtr = DistributedRuntime::from_settings(rt.clone()).await.unwrap();
        let ns = dtr.namespace("test".to_string()).unwrap();
        let cp = ns.component("component".to_string()).unwrap();

        // Subscribe on the SAME component we publish on: the component's
        // event subject ("{component subject}.test") differs from the
        // namespace's, so the previous ns.subscribe() never matched the
        // cp.publish() subject and the assertion below was unreachable.
        let mut subscriber = cp.subscribe("test").await.unwrap();

        // Publish a message
        cp.publish("test", &"test_message".to_string())
            .await
            .unwrap();

        // Bound the wait and fail loudly if nothing arrives; the old
        // `if let Some(msg)` silently passed when no message was delivered.
        let msg = tokio::time::timeout(std::time::Duration::from_secs(5), subscriber.next())
            .await
            .expect("timed out waiting for the published message")
            .expect("subscription closed before a message arrived");
        let received = String::from_utf8(msg.payload.to_vec()).unwrap();
        assert_eq!(received, "\"test_message\"");

        rt.shutdown();
    }
}
...@@ -113,7 +113,7 @@ impl EndpointConfigBuilder { ...@@ -113,7 +113,7 @@ impl EndpointConfigBuilder {
let info = ComponentEndpointInfo { let info = ComponentEndpointInfo {
component: endpoint.component.name.clone(), component: endpoint.component.name.clone(),
endpoint: endpoint.name.clone(), endpoint: endpoint.name.clone(),
namespace: endpoint.component.namespace.clone(), namespace: endpoint.component.namespace.name.clone(),
lease_id: lease.id(), lease_id: lease.id(),
transport: TransportType::NatsTcp(endpoint.subject_to(lease.id())), transport: TransportType::NatsTcp(endpoint.subject_to(lease.id())),
}; };
......
...@@ -101,7 +101,7 @@ mod tests { ...@@ -101,7 +101,7 @@ mod tests {
let ns = dtr.namespace("test".to_string()).unwrap(); let ns = dtr.namespace("test".to_string()).unwrap();
// Create a subscriber // Create a subscriber
let subscriber = ns.subscribe("test").await.unwrap(); let mut subscriber = ns.subscribe("test").await.unwrap();
// Publish a message // Publish a message
ns.publish("test", &"test_message".to_string()) ns.publish("test", &"test_message".to_string())
......
...@@ -73,12 +73,28 @@ impl EndpointInfo { ...@@ -73,12 +73,28 @@ impl EndpointInfo {
i64::from_str_radix(id, 16).map_err(|e| error!("Invalid id format: {}", e)) i64::from_str_radix(id, 16).map_err(|e| error!("Invalid id format: {}", e))
} }
} }
// TODO: This is _really_ close to the async_nats::service::Stats object,
// but it's missing a few fields like "name", so use a temporary struct
// for easy deserialization. Ideally, this type already exists or can
// be exposed in the library somewhere.
/// Stats structure returned from NATS service API
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)] #[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct Metrics(pub serde_json::Value); pub struct Metrics {
// Standard NATS Service API fields
pub average_processing_time: f64,
pub last_error: String,
pub num_errors: u64,
pub num_requests: u64,
pub processing_time: u64,
pub queue_group: String,
// Field containing custom stats handler data
pub data: serde_json::Value,
}
impl Metrics { impl Metrics {
pub fn decode<T: for<'de> Deserialize<'de>>(self) -> Result<T> { pub fn decode<T: for<'de> Deserialize<'de>>(self) -> Result<T> {
serde_json::from_value(self.0).map_err(Into::into) serde_json::from_value(self.data).map_err(Into::into)
} }
} }
...@@ -153,12 +169,28 @@ mod tests { ...@@ -153,12 +169,28 @@ mod tests {
EndpointInfo { EndpointInfo {
name: "endpoint1".to_string(), name: "endpoint1".to_string(),
subject: "subject1".to_string(), subject: "subject1".to_string(),
data: Some(Metrics(serde_json::json!({"key": "value1"}))), data: Some(Metrics {
average_processing_time: 0.1,
last_error: "none".to_string(),
num_errors: 0,
num_requests: 10,
processing_time: 100,
queue_group: "group1".to_string(),
data: serde_json::json!({"key": "value1"}),
}),
}, },
EndpointInfo { EndpointInfo {
name: "endpoint2-foo".to_string(), name: "endpoint2-foo".to_string(),
subject: "subject2".to_string(), subject: "subject2".to_string(),
data: Some(Metrics(serde_json::json!({"key": "value1"}))), data: Some(Metrics {
average_processing_time: 0.1,
last_error: "none".to_string(),
num_errors: 0,
num_requests: 10,
processing_time: 100,
queue_group: "group1".to_string(),
data: serde_json::json!({"key": "value1"}),
}),
}, },
], ],
}, },
...@@ -171,12 +203,28 @@ mod tests { ...@@ -171,12 +203,28 @@ mod tests {
EndpointInfo { EndpointInfo {
name: "endpoint1".to_string(), name: "endpoint1".to_string(),
subject: "subject1".to_string(), subject: "subject1".to_string(),
data: Some(Metrics(serde_json::json!({"key": "value1"}))), data: Some(Metrics {
average_processing_time: 0.1,
last_error: "none".to_string(),
num_errors: 0,
num_requests: 10,
processing_time: 100,
queue_group: "group1".to_string(),
data: serde_json::json!({"key": "value1"}),
}),
}, },
EndpointInfo { EndpointInfo {
name: "endpoint2-bar".to_string(), name: "endpoint2-bar".to_string(),
subject: "subject2".to_string(), subject: "subject2".to_string(),
data: Some(Metrics(serde_json::json!({"key": "value2"}))), data: Some(Metrics {
average_processing_time: 0.1,
last_error: "none".to_string(),
num_errors: 0,
num_requests: 10,
processing_time: 100,
queue_group: "group1".to_string(),
data: serde_json::json!({"key": "value2"}),
}),
}, },
], ],
}, },
......
...@@ -108,35 +108,6 @@ impl Client { ...@@ -108,35 +108,6 @@ impl Client {
Ok(subscription) Ok(subscription)
} }
// todo - deprecate - move to service subscriber
pub async fn get_endpoints(
&self,
service_name: &str,
timeout: time::Duration,
) -> Result<Vec<Bytes>, anyhow::Error> {
let subject = format!("$SRV.STATS.{}", service_name);
let reply_subject = format!("_INBOX.{}", nuid::next());
let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
let deadline = tokio::time::Instant::now() + timeout;
// Publish the request with the reply-to subject
self.client
.publish_with_reply(subject, reply_subject, "".into())
.await?;
// Set a timeout to gather responses
let mut responses = Vec::new();
// let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
while let Ok(Some(message)) = time::timeout_at(deadline, subscription.next()).await {
// log::debug!("get endpoint received message before timeout: {:?}", message);
responses.push(message.payload);
}
Ok(responses)
}
// /// create a new stream // /// create a new stream
// async fn get_or_create_work_queue_stream( // async fn get_or_create_work_queue_stream(
// &self, // &self,
...@@ -272,35 +243,6 @@ impl Client { ...@@ -272,35 +243,6 @@ impl Client {
// Ok(()) // Ok(())
// } // }
// pub async fn get_endpoints(
// &self,
// service_name: &str,
// timeout: Duration,
// ) -> Result<Vec<Bytes>, anyhow::Error> {
// let subject = format!("$SRV.STATS.{}", service_name);
// let reply_subject = format!("_INBOX.{}", nuid::next());
// let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
// // Publish the request with the reply-to subject
// self.client
// .publish_with_reply(subject, reply_subject, "".into())
// .await?;
// // Set a timeout to gather responses
// let mut responses = Vec::new();
// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
// let start = time::Instant::now();
// while let Ok(Some(message)) = time::timeout(timeout, subscription.next()).await {
// responses.push(message.payload);
// if start.elapsed() > timeout {
// break;
// }
// }
// Ok(responses)
// }
// pub fn frontend_client(&self, request_id: String) -> SpecializedClient { // pub fn frontend_client(&self, request_id: String) -> SpecializedClient {
// SpecializedClient::new(self.client.clone(), ClientKind::Frontend, request_id) // SpecializedClient::new(self.client.clone(), ClientKind::Frontend, request_id)
// } // }
...@@ -691,35 +633,6 @@ mod tests { ...@@ -691,35 +633,6 @@ mod tests {
// assert_eq!(initial_work_queue_count, work_queue_count); // assert_eq!(initial_work_queue_count, work_queue_count);
// } // }
// pub async fn get_endpoints(
// &self,
// service_name: &str,
// timeout: Duration,
// ) -> Result<Vec<Bytes>, anyhow::Error> {
// let subject = format!("$SRV.STATS.{}", service_name);
// let reply_subject = format!("_INBOX.{}", nuid::next());
// let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
// // Publish the request with the reply-to subject
// self.client
// .publish_with_reply(subject, reply_subject, "".into())
// .await?;
// // Set a timeout to gather responses
// let mut responses = Vec::new();
// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
// let start = time::Instant::now();
// while let Ok(Some(message)) = time::timeout(timeout, subscription.next()).await {
// responses.push(message.payload);
// if start.elapsed() > timeout {
// break;
// }
// }
// Ok(responses)
// }
// async fn connect(config: Arc<Config>) -> Result<NatsClient> { // async fn connect(config: Arc<Config>) -> Result<NatsClient> {
// let client = ClientOptions::builder() // let client = ClientOptions::builder()
// .server(config.nats_address.clone()) // .server(config.nats_address.clone())
...@@ -852,35 +765,6 @@ mod tests { ...@@ -852,35 +765,6 @@ mod tests {
// pub fn service_builder(&self) -> NatsServiceBuilder { // pub fn service_builder(&self) -> NatsServiceBuilder {
// self.client.service_builder() // self.client.service_builder()
// } // }
// pub async fn get_endpoints(
// &self,
// service_name: &str,
// timeout: Duration,
// ) -> Result<Vec<Bytes>, anyhow::Error> {
// let subject = format!("$SRV.STATS.{}", service_name);
// let reply_subject = format!("_INBOX.{}", nuid::next());
// let mut subscription = self.client.subscribe(reply_subject.clone()).await?;
// // Publish the request with the reply-to subject
// self.client
// .publish_with_reply(subject, reply_subject, "".into())
// .await?;
// // Set a timeout to gather responses
// let mut responses = Vec::new();
// // let mut response_stream = subscription.take_while(|_| futures::future::ready(true));
// let start = tokio::time::Instant::now();
// while let Ok(Some(message)) = tokio::time::timeout(timeout, subscription.next()).await {
// responses.push(message.payload);
// if start.elapsed() > timeout {
// break;
// }
// }
// Ok(responses)
// }
// } // }
// #[derive(Debug, Clone, Serialize, Deserialize)] // #[derive(Debug, Clone, Serialize, Deserialize)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment