Unverified Commit 9210a26d authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

refactor: Refactor kv event publishers (#1287)

parent 39dcdf1f
......@@ -148,7 +148,7 @@ fn dynamo_create_kv_publisher(
{
Ok(drt) => {
let backend = drt.namespace(namespace)?.component(component)?;
KvEventPublisher::new(backend, worker_id, kv_block_size)
KvEventPublisher::new(backend, worker_id, kv_block_size, None)
}
Err(e) => Err(e),
}
......
......@@ -22,7 +22,7 @@ use rs::traits::events::EventSubscriber;
use tracing;
use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::create_stored_blocks;
use llm_rs::kv_router::publisher::{create_stored_blocks, KvEventSourceConfig};
#[pyclass]
pub(crate) struct KvRouter {
......@@ -165,21 +165,23 @@ impl ZmqKvEventPublisherConfig {
#[pyclass]
pub(crate) struct ZmqKvEventPublisher {
inner: llm_rs::kv_router::publisher::ZmqKvEventPublisher,
inner: llm_rs::kv_router::publisher::KvEventPublisher,
}
#[pymethods]
impl ZmqKvEventPublisher {
#[new]
fn new(component: Component, config: ZmqKvEventPublisherConfig) -> PyResult<Self> {
let mut inner =
llm_rs::kv_router::publisher::ZmqKvEventPublisher::new(config.kv_block_size);
inner.start_background_task(
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner,
config.worker_id,
config.zmq_endpoint,
config.zmq_topic,
);
config.kv_block_size,
Some(KvEventSourceConfig::Zmq {
endpoint: config.zmq_endpoint,
topic: config.zmq_topic,
}),
)
.map_err(to_pyerr)?;
Ok(Self { inner })
}
......@@ -203,8 +205,10 @@ impl KvEventPublisher {
component.inner,
worker_id,
kv_block_size,
None,
)
.map_err(to_pyerr)?;
Ok(Self {
inner: inner.into(),
kv_block_size,
......
......@@ -19,7 +19,7 @@ use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_ENDPOINT,
};
use async_trait::async_trait;
use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider, RuntimeProvider};
use dynamo_runtime::traits::{events::EventPublisher, DistributedRuntimeProvider};
use dynamo_runtime::{
component::Component,
pipeline::{
......@@ -32,6 +32,7 @@ use dynamo_runtime::{
use futures::stream;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use rmp_serde as rmps;
use serde::Deserialize;
......@@ -44,173 +45,163 @@ use zeromq::{Socket, SocketRecv, SubSocket};
// KV Event Publishers -----------------------------------------------------
// -------------------------------------------------------------------------
pub struct KvEventPublisher {
kv_block_size: usize,
tx: mpsc::UnboundedSender<KvCacheEvent>,
/// Configure the source of KV events.
///
/// Passing a value of this type to [`KvEventPublisher::new`] makes the
/// publisher start a background event source in addition to accepting events
/// through [`KvEventPublisher::publish`].
/// Currently, only ZMQ is supported.
///
/// The common traits are derived so that a configuration can be logged,
/// duplicated and compared by callers and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KvEventSourceConfig {
    /// Read KV events from a ZMQ subscription.
    Zmq {
        /// The ZMQ endpoint the listener connects to.
        endpoint: String,
        /// The topic to subscribe to (an empty string subscribes to all topics).
        topic: String,
    },
}
impl KvEventPublisher {
pub fn new(component: Component, worker_id: i64, kv_block_size: usize) -> Result<Self> {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let p = KvEventPublisher { tx, kv_block_size };
/// The source of KV events.
///
/// This is the running counterpart of [`KvEventSourceConfig`]: it is produced
/// by [`KvEventSource::start`] and owns whatever must be torn down when the
/// source is shut down.
enum KvEventSource {
/// A source backed by a ZMQ listener task.
Zmq {
/// Handle of the spawned ZMQ listener task; it is aborted on shutdown.
zmq_handle: tokio::task::JoinHandle<()>,
},
}
start_publish_task(component, worker_id, rx);
Ok(p)
}
impl KvEventSource {
/// Start the event source from a [`KvEventSourceConfig`].
fn start(
component: Component,
kv_block_size: usize,
source_config: KvEventSourceConfig,
cancellation_token: CancellationToken,
tx: mpsc::UnboundedSender<KvCacheEvent>,
) -> Result<Self> {
match source_config {
KvEventSourceConfig::Zmq { endpoint, topic } => {
let zmq_handle = component
.drt()
.runtime()
.secondary()
.spawn(start_zmq_listener(
endpoint,
topic,
tx,
cancellation_token.clone(),
kv_block_size,
));
pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> {
tracing::trace!("Publish event: {:?}", event);
self.tx.send(event)
Ok(KvEventSource::Zmq { zmq_handle })
}
}
pub fn kv_block_size(&self) -> usize {
self.kv_block_size
}
}
fn start_publish_task(
component: Component,
worker_id: i64,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
) {
let component_clone = component.clone();
tracing::info!("Publishing KV Events to subject: {}", KV_EVENT_SUBJECT);
_ = component.drt().runtime().secondary().spawn(async move {
while let Some(event) = rx.recv().await {
let router_event = RouterEvent::new(worker_id, event);
component_clone
.publish(KV_EVENT_SUBJECT, &router_event)
.await
.unwrap();
fn shutdown(&self) {
match self {
KvEventSource::Zmq { zmq_handle } => {
zmq_handle.abort();
}
}
}
});
}
// vLLM and SGLang use multi-processing to launch engine-core processes
// We use zmq to publish events from these processes to a socket
// For more info on zmq: https://zeromq.org/
// This publisher reads those events and publishes them to NATS
// The indexer will get the events from NATS and put them in the global prefix tree.
pub struct ZmqKvEventPublisher {
/// A publisher of KV events.
pub struct KvEventPublisher {
/// The size of the KV block.
kv_block_size: usize,
processor_handle: Option<tokio::task::JoinHandle<()>>,
zmq_handle: Option<tokio::task::JoinHandle<()>>,
zmq_token: Option<dynamo_runtime::CancellationToken>,
warning_count: Arc<AtomicU32>,
/// The source of KV events.
/// Can be `None` if all events are provided through [`KvEventPublisher::publish`].
source: Option<KvEventSource>,
/// The cancellation token.
cancellation_token: CancellationToken,
/// The channel to send events to.
tx: mpsc::UnboundedSender<KvCacheEvent>,
}
impl ZmqKvEventPublisher {
pub fn new(kv_block_size: usize) -> Self {
Self {
kv_block_size,
processor_handle: None,
zmq_handle: None,
zmq_token: None,
warning_count: Arc::new(AtomicU32::new(0)),
}
}
pub fn start_background_task(
&mut self,
impl KvEventPublisher {
pub fn new(
component: Component,
worker_id: i64,
zmq_endpoint: String,
zmq_topic: String,
) {
let kv_block_size = self.kv_block_size;
let warning_count = self.warning_count.clone();
let (raw_tx, raw_rx) = mpsc::unbounded_channel::<(u64, Vec<u8>)>();
kv_block_size: usize,
source_config: Option<KvEventSourceConfig>,
) -> Result<Self> {
let cancellation_token = CancellationToken::new();
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
let zmq_token = component.rt().child_token();
self.zmq_token = Some(zmq_token.clone());
// Create our event source (if any)
let mut source = None;
if let Some(config) = source_config {
source = Some(KvEventSource::start(
component.clone(),
kv_block_size,
config,
cancellation_token.clone(),
tx.clone(),
)?);
}
// Spawn async ZMQ listener
self.zmq_handle = Some(
component
.drt()
.runtime()
.secondary()
.spawn(start_zmq_listener(
zmq_endpoint,
zmq_topic,
raw_tx,
zmq_token.clone(),
)),
);
self.processor_handle = Some(component.drt().runtime().secondary().spawn(
start_event_processor(
raw_rx,
.spawn(start_event_processor(
component,
worker_id,
kv_block_size,
warning_count,
zmq_token,
),
cancellation_token.clone(),
rx,
));
Ok(Self {
kv_block_size,
source,
cancellation_token,
tx,
})
}
/// Queue a KV cache event for publication.
///
/// The event is logged at trace level and pushed onto the publisher's
/// unbounded channel; the call itself does not wait for the event to be
/// published.
///
/// # Errors
///
/// Returns the channel's `SendError`, which carries the rejected event, when
/// the send fails (per tokio, this happens once the receiving half of the
/// channel has been closed or dropped).
pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> {
tracing::trace!("Publish event: {:?}", event);
self.tx.send(event)
}
/// Returns the KV block size stored in this publisher.
pub fn kv_block_size(&self) -> usize {
self.kv_block_size
}
pub fn shutdown(&mut self) {
if let Some(token) = self.zmq_token.take() {
token.cancel();
if !self.cancellation_token.is_cancelled() {
self.cancellation_token.cancel();
}
if let Some(handle) = self.zmq_handle.take() {
handle.abort();
if let Some(source) = self.source.take() {
source.shutdown();
}
if let Some(handle) = self.processor_handle.take() {
handle.abort();
}
}
// Dropping a publisher shuts it down, so callers do not have to call
// `shutdown` explicitly before the value goes out of scope.
impl Drop for KvEventPublisher {
/// Delegates to [`KvEventPublisher::shutdown`].
fn drop(&mut self) {
self.shutdown();
}
}
async fn start_event_processor<P: EventPublisher>(
mut raw_rx: mpsc::UnboundedReceiver<(u64, Vec<u8>)>,
component: P,
async fn start_event_processor<P: EventPublisher + Send + Sync + 'static>(
publisher: P,
worker_id: i64,
kv_block_size: usize,
warning_count: Arc<AtomicU32>,
cancellation_token: dynamo_runtime::CancellationToken,
cancellation_token: CancellationToken,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
) {
loop {
tokio::select! {
biased;
// Check for cancellation
_ = cancellation_token.cancelled() => {
tracing::debug!("Event processor received cancellation signal");
tracing::info!("KV Event source received cancellation signal");
break;
}
// Process incoming messages
msg = raw_rx.recv() => {
let Some((seq, payload)) = msg else {
tracing::debug!("Event processor channel closed");
event = rx.recv() => {
let Some(event) = event else {
tracing::debug!("Event processor channel closed.");
break;
};
let batch_result = rmps::from_slice::<KvEventBatch>(&payload);
let Ok(batch) = batch_result else {
let e = batch_result.unwrap_err();
tracing::warn!(error=%e, "Failed to decode KVEventBatch msgpack");
continue;
};
for raw_evt in batch.events.into_iter() {
let Some(event) = convert_event(raw_evt, seq, kv_block_size, &warning_count) else {
// Case where convert_event returns None
continue;
};
// Encapsulate in a router event and publish.
let router_event = RouterEvent::new(worker_id, event);
if let Err(e) = component.publish(KV_EVENT_SUBJECT, &router_event).await {
tracing::warn!(error=%e, "Failed to publish router event.");
}
if let Err(e) = publisher.publish(KV_EVENT_SUBJECT, &router_event).await {
tracing::error!("Failed to publish event: {}", e);
}
}
}
}
tracing::debug!("Event processor exiting");
}
// Error handling configuration for ZMQ operations
......@@ -230,8 +221,9 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 {
async fn start_zmq_listener(
zmq_endpoint: String,
zmq_topic: String,
raw_tx: mpsc::UnboundedSender<(u64, Vec<u8>)>,
zmq_token: dynamo_runtime::CancellationToken,
tx: mpsc::UnboundedSender<KvCacheEvent>,
cancellation_token: CancellationToken,
kv_block_size: usize,
) {
tracing::debug!(
"KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')",
......@@ -239,6 +231,8 @@ async fn start_zmq_listener(
zmq_topic
);
let warning_count = Arc::new(AtomicU32::new(0));
let mut socket = SubSocket::new();
// Subscribe to the requested topic (empty string == all topics)
......@@ -259,7 +253,7 @@ async fn start_zmq_listener(
biased;
// Check for cancellation
_ = zmq_token.cancelled() => {
_ = cancellation_token.cancelled() => {
tracing::info!("ZMQ listener received cancellation signal");
break;
}
......@@ -292,7 +286,6 @@ async fn start_zmq_listener(
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
continue;
};
// Reset error count on successful message
consecutive_errors = 0;
......@@ -303,8 +296,10 @@ async fn start_zmq_listener(
tracing::warn!(expected=3, actual=%frames.len(), "Received unexpected ZMQ frame count");
continue;
}
let payload = frames.remove(2);
let seq_bytes = frames.remove(1);
// Extract the payload and sequence number.
let payload = frames.pop().unwrap();
let seq_bytes = frames.pop().unwrap();
if seq_bytes.len() != 8 {
tracing::warn!(expected=8, actual=%seq_bytes.len(), "Invalid sequence number byte length");
......@@ -312,14 +307,28 @@ async fn start_zmq_listener(
}
let seq = u64::from_be_bytes(seq_bytes.try_into().unwrap());
if raw_tx.send((seq, payload)).is_err() {
// Decode our batch of events.
let batch_result = rmps::from_slice::<KvEventBatch>(&payload);
let Ok(batch) = batch_result else {
let e = batch_result.unwrap_err();
tracing::warn!(error=%e, "Failed to decode KVEventBatch msgpack");
continue;
};
// For each of our events, convert them to [`KvCacheEvent`] and send to the event_processor.
for raw_event in batch.events.into_iter() {
if let Some(event) = convert_event(raw_event, seq, kv_block_size, &warning_count) {
if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped");
break;
return;
}
}
}
}
}
tracing::debug!("ZMQ listener exiting");
}
}
/// Convert a raw event coming from the ZMQ channel into the internal
......@@ -630,6 +639,7 @@ mod test_event_processing {
#[cfg(test)]
mod tests_startup_helpers {
use super::*;
use crate::kv_router::protocols::ExternalSequenceBlockHash;
use async_trait;
use bytes::Bytes;
use std::sync::{Arc, Mutex};
......@@ -691,53 +701,35 @@ mod tests_startup_helpers {
}
//--------------------------------------------------------------------
// Test start_event_processor in isolation
// Test start_event_processor
//--------------------------------------------------------------------
#[tokio::test]
async fn test_start_event_processor_sends_router_event() {
let kv_block_size = 4;
let worker_id = 99;
// 1) build a one-item KvEventBatch and msgpack-encode it
let batch = KvEventBatch {
ts: 0.0,
events: vec![RawKvEvent::BlockRemoved {
block_hashes: vec![1, 2],
}],
};
let payload = rmps::to_vec(&batch).unwrap();
async fn test_start_event_processor() {
let (component, published) = MockComponent::new();
let token = dynamo_runtime::CancellationToken::new();
let event = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)],
}),
};
// 2) channel feeding the processor
let (tx, rx) = mpsc::unbounded_channel::<(u64, Vec<u8>)>();
tx.send((123, payload.clone())).unwrap(); // seq = 123
let token = CancellationToken::new();
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
tx.send(event).unwrap();
drop(tx);
// 3) mock component to capture output
let (comp, published) = MockComponent::new();
// 4) run the function under test (let it consume exactly one msg)
let handle = tokio::spawn(start_event_processor(
rx,
comp,
worker_id,
kv_block_size,
Arc::new(AtomicU32::new(0)),
token,
));
let handle = tokio::spawn(start_event_processor(component, 1, token, rx));
tokio::time::timeout(std::time::Duration::from_secs(1), handle)
tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
.await
.unwrap()
.unwrap();
// 5) assert we have exactly one RouterEvent pushed with right worker_id
let published = published.lock().unwrap();
let (subject, bytes) = &published[0];
assert_eq!(published.len(), 1);
let (subject, _) = &published[0];
assert_eq!(subject, &KV_EVENT_SUBJECT.to_string());
assert_eq!(bytes.first(), payload.first())
}
//--------------------------------------------------------------------
......@@ -747,7 +739,7 @@ mod tests_startup_helpers {
#[tokio::test]
async fn test_start_zmq_listener_pushes_to_channel() {
// Prepare channel that listener should fill
let (tx, mut rx) = mpsc::unbounded_channel::<(u64, Vec<u8>)>();
let (tx, mut rx) = mpsc::unbounded_channel::<KvCacheEvent>();
// ZMQ TCP endpoint using localhost with fixed port
let endpoint = "tcp://127.0.0.1:15555";
......@@ -763,7 +755,7 @@ mod tests_startup_helpers {
// Spawn async listener
let listener_handle = tokio::spawn({
let token = token.clone();
start_zmq_listener(endpoint.to_string(), topic, tx, token)
start_zmq_listener(endpoint.to_string(), topic, tx, token, 4)
});
// Give time for the connection to establish
......@@ -771,7 +763,18 @@ mod tests_startup_helpers {
// Send synthetic 3-frame message: [topic, seq(8B), payload]
let seq: u64 = 77;
let payload = Bytes::from("hello");
let events = vec![RawKvEvent::BlockStored {
block_hashes: vec![42],
parent_block_hash: None,
token_ids: vec![0, 1, 2, 3],
block_size: 4,
lora_id: None,
}];
let batch = KvEventBatch { ts: 0.0, events };
let payload = Bytes::from(rmps::to_vec(&batch).unwrap());
let frames = vec![
Bytes::from(""),
......@@ -789,9 +792,19 @@ mod tests_startup_helpers {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Check that we received the message
let (got_seq, got_payload) = rx.try_recv().expect("no message received");
assert_eq!(got_seq, seq);
assert_eq!(got_payload, payload);
let event = rx.try_recv().expect("no message received");
let KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks,
}) = event.data
else {
panic!("expected KvCacheStoreData");
};
assert!(parent_hash.is_none());
assert_eq!(blocks.len(), 1);
assert_eq!(blocks[0].block_hash.0, 42);
// Stop the listener
token.cancel();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment