Unverified commit 5a3c52ab, authored by Yan Ru Pei and committed by GitHub
Browse files

chore: more mocker rusty cleanups (#6271)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent a582f4df
...@@ -6,20 +6,19 @@ ...@@ -6,20 +6,19 @@
//! The core mocker logic lives in the `dynamo-mocker` crate. //! The core mocker logic lives in the `dynamo-mocker` crate.
//! This module provides the runtime-dependent engine wrapper. //! This module provides the runtime-dependent engine wrapper.
use std::collections::HashMap; use std::sync::Arc;
use std::sync::{Arc, LazyLock};
use std::time::Duration; use std::time::Duration;
use anyhow::Result; use anyhow::Result;
use dashmap::DashMap;
use futures::StreamExt; use futures::StreamExt;
use rand::Rng; use rand::Rng;
use tokio::sync::{Mutex, OnceCell, mpsc}; use tokio::sync::{Notify, OnceCell, mpsc};
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use uuid::Uuid; use uuid::Uuid;
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::config::environment_names::mocker as env_mocker;
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::{ use dynamo_runtime::{
component::Component, component::Component,
...@@ -35,7 +34,7 @@ use dynamo_kv_router::protocols::KvCacheEvent; ...@@ -35,7 +34,7 @@ use dynamo_kv_router::protocols::KvCacheEvent;
// Re-export from dynamo-mocker for convenience // Re-export from dynamo-mocker for convenience
use dynamo_mocker::bootstrap::{BootstrapServer, connect_to_prefill}; use dynamo_mocker::bootstrap::{BootstrapServer, connect_to_prefill};
use dynamo_mocker::protocols::{OutputSignal, WorkerType}; use dynamo_mocker::protocols::OutputSignal;
pub use dynamo_mocker::{ pub use dynamo_mocker::{
DirectRequest, KvCacheEventSink, MockEngineArgs, MockEngineArgsBuilder, Scheduler, bootstrap, DirectRequest, KvCacheEventSink, MockEngineArgs, MockEngineArgsBuilder, Scheduler, bootstrap,
evictor, kv_manager, perf_model, protocols, running_mean, scheduler, sequence, evictor, kv_manager, perf_model, protocols, running_mean, scheduler, sequence,
...@@ -43,9 +42,6 @@ pub use dynamo_mocker::{ ...@@ -43,9 +42,6 @@ pub use dynamo_mocker::{
pub const MOCKER_COMPONENT: &str = "mocker"; pub const MOCKER_COMPONENT: &str = "mocker";
static MOCKER_DIRECT_SYNC: LazyLock<bool> =
LazyLock::new(|| dynamo_runtime::config::env_is_truthy(env_mocker::DYN_MOCKER_SYNC_DIRECT));
/// Wrapper to adapt KvEventPublisher to the KvCacheEventSink trait /// Wrapper to adapt KvEventPublisher to the KvCacheEventSink trait
struct KvEventSinkAdapter(KvEventPublisher); struct KvEventSinkAdapter(KvEventPublisher);
...@@ -63,10 +59,10 @@ fn generate_random_token() -> TokenIdType { ...@@ -63,10 +59,10 @@ fn generate_random_token() -> TokenIdType {
} }
/// AsyncEngine wrapper around the Scheduler that generates random character tokens /// AsyncEngine wrapper around the Scheduler that generates random character tokens
#[derive(Clone)]
pub struct MockVllmEngine { pub struct MockVllmEngine {
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>, active_requests: Arc<DashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>,
request_senders: Arc<OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>>, request_senders: OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>,
senders_ready: Notify,
engine_args: MockEngineArgs, engine_args: MockEngineArgs,
/// Bootstrap server for prefill workers in disaggregated mode /// Bootstrap server for prefill workers in disaggregated mode
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>, bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
...@@ -74,11 +70,12 @@ pub struct MockVllmEngine { ...@@ -74,11 +70,12 @@ pub struct MockVllmEngine {
impl MockVllmEngine { impl MockVllmEngine {
/// Create a new MockVllmEngine with the given parameters /// Create a new MockVllmEngine with the given parameters
pub fn new(args: MockEngineArgs) -> Self { pub fn new(engine_args: MockEngineArgs) -> Self {
Self { Self {
active_requests: Arc::new(Mutex::new(HashMap::new())), active_requests: Arc::new(DashMap::new()),
request_senders: Arc::new(OnceCell::new()), request_senders: OnceCell::new(),
engine_args: args, senders_ready: Notify::new(),
engine_args,
bootstrap_server: Arc::new(OnceCell::new()), bootstrap_server: Arc::new(OnceCell::new()),
} }
} }
...@@ -98,7 +95,7 @@ impl MockVllmEngine { ...@@ -98,7 +95,7 @@ impl MockVllmEngine {
} }
// Start bootstrap server for prefill workers in disaggregated mode // Start bootstrap server for prefill workers in disaggregated mode
if self.engine_args.worker_type == WorkerType::Prefill if self.engine_args.is_prefill()
&& let Some(port) = self.engine_args.bootstrap_port && let Some(port) = self.engine_args.bootstrap_port
{ {
let server = BootstrapServer::start(port, cancel_token.clone()).await?; let server = BootstrapServer::start(port, cancel_token.clone()).await?;
...@@ -106,105 +103,79 @@ impl MockVllmEngine { ...@@ -106,105 +103,79 @@ impl MockVllmEngine {
tracing::info!(port = port, "Bootstrap server started for prefill worker"); tracing::info!(port = port, "Bootstrap server started for prefill worker");
} }
// Determine if we need KV event publishers (prefix caching enabled and not decode worker) let kv_component = if self.engine_args.needs_kv_publisher() {
let needs_kv_publisher = self.engine_args.enable_prefix_caching
&& self.engine_args.worker_type != WorkerType::Decode;
if needs_kv_publisher {
tracing::info!( tracing::info!(
"Initializing KV event publisher with block_size {}, enable_local_indexer={}", "Initializing KV event publisher with block_size {}, enable_local_indexer={}",
self.engine_args.block_size, self.engine_args.block_size,
self.engine_args.enable_local_indexer self.engine_args.enable_local_indexer
); );
} Some(&component)
} else {
None
};
let schedulers = self.start_schedulers( let schedulers = self.start_schedulers(kv_component, cancel_token.clone());
self.engine_args.clone(),
self.active_requests.clone(),
if needs_kv_publisher {
Some(component.clone())
} else {
None
},
cancel_token.clone(),
);
Self::start_metrics_publishing(&schedulers, Some(component.clone()), cancel_token.clone()) Self::start_metrics_publishing(&schedulers, component, cancel_token.clone()).await?;
.await?;
Ok(()) Ok(())
} }
/// Send a request to the appropriate scheduler. /// Send a request to the appropriate scheduler, waiting for initialization if needed.
///
/// Set `DYN_MOCKER_SYNC_DIRECT=1` to use the original direct path.
/// - `DYN_MOCKER_SYNC_DIRECT=1` (original, race-condition prone): 922/1000 pass
/// - `DYN_MOCKER_SYNC_DIRECT=0` (use timeout to wait for init): 1000/1000 pass
pub async fn direct(&self, request: DirectRequest, dp_rank: usize) { pub async fn direct(&self, request: DirectRequest, dp_rank: usize) {
// `direct()` can be called before `start_schedulers()` finishes populating if let Some(senders) = self.request_senders.get() {
// `request_senders` under load. The original path panics immediately; the
// default path waits briefly for initialization to complete.
if *MOCKER_DIRECT_SYNC {
let senders = self.request_senders.get().expect("Not initialized");
let _ = senders[dp_rank].send(request); let _ = senders[dp_rank].send(request);
return; return;
} }
// Poll request_senders until initialized (or time out) to avoid the startup race. // Register the waiter *before* re-checking to avoid a TOCTOU race
let start = std::time::Instant::now(); // where `start_schedulers` sets + notifies between our check and subscribe.
loop { let notified = self.senders_ready.notified();
if let Some(senders) = self.request_senders.get() { if let Some(senders) = self.request_senders.get() {
let _ = senders[dp_rank].send(request); let _ = senders[dp_rank].send(request);
return; return;
}
// We can parameterize the timeout to be more flexible.
// For example, on production this could be very short, but in a
// CPU-heavy test environment, this should be very high.
if start.elapsed() > Duration::from_secs(10) {
panic!("Scheduler initialization timed out after 10s");
}
tokio::time::sleep(Duration::from_millis(50)).await;
} }
notified.await;
let senders = self
.request_senders
.get()
.expect("must be set after notify");
let _ = senders[dp_rank].send(request);
} }
/// Create schedulers and spawn their background tasks for distributing token notifications /// Create schedulers and spawn their background tasks for distributing token notifications
fn start_schedulers( fn start_schedulers(
&self, &self,
args: MockEngineArgs, component: Option<&Component>,
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
component: Option<Component>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> Vec<Scheduler> { ) -> Vec<Scheduler> {
let args = &self.engine_args;
let mut schedulers = Vec::<Scheduler>::new(); let mut schedulers = Vec::<Scheduler>::new();
let mut senders = Vec::with_capacity(args.dp_size as usize); let mut senders = Vec::with_capacity(args.dp_size as usize);
// Create multiple schedulers and their background tasks
for dp_rank in 0..args.dp_size { for dp_rank in 0..args.dp_size {
// Create a shared output channel that this scheduler will use
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
// Create a KvEventPublisher for THIS dp_rank if component is provided let kv_event_sink: Option<Arc<dyn KvCacheEventSink>> = component.and_then(|comp| {
let kv_event_sink: Option<Arc<dyn KvCacheEventSink>> = match KvEventPublisher::new_with_local_indexer(
component comp.clone(),
.as_ref() args.block_size as u32,
.and_then(|comp| { None,
match KvEventPublisher::new_with_local_indexer( args.enable_local_indexer,
comp.clone(), dp_rank,
args.block_size as u32, ) {
None, Ok(publisher) => {
args.enable_local_indexer, Some(Arc::new(KvEventSinkAdapter(publisher)) as Arc<dyn KvCacheEventSink>)
dp_rank, }
) { Err(e) => {
Ok(publisher) => Some(Arc::new(KvEventSinkAdapter(publisher)) tracing::error!(
as Arc<dyn KvCacheEventSink>), "Failed to create KV event publisher for dp_rank {dp_rank}: {e}"
Err(e) => { );
tracing::error!( None
"Failed to create KV event publisher for dp_rank {dp_rank}: {e}" }
); }
None });
}
}
});
let scheduler = Scheduler::new( let scheduler = Scheduler::new(
args.clone(), args.clone(),
...@@ -217,9 +188,7 @@ impl MockVllmEngine { ...@@ -217,9 +188,7 @@ impl MockVllmEngine {
senders.push(scheduler.request_sender()); senders.push(scheduler.request_sender());
schedulers.push(scheduler); schedulers.push(scheduler);
// Spawn a background task for this scheduler to distribute token notifications to active requests let active_requests_clone = self.active_requests.clone();
// let output_rx = Arc::new(Mutex::new(output_rx));
let active_requests_clone = active_requests.clone();
let cancel_token_cloned = cancel_token.clone(); let cancel_token_cloned = cancel_token.clone();
tokio::spawn(async move { tokio::spawn(async move {
...@@ -230,18 +199,13 @@ impl MockVllmEngine { ...@@ -230,18 +199,13 @@ impl MockVllmEngine {
break; // Channel closed break; // Channel closed
}; };
// Notify the specific request that a token was generated if let Some(request_tx) = active_requests_clone.get(&signal.uuid) {
let active = active_requests_clone.lock().await;
if let Some(request_tx) = active.get(&signal.uuid) {
let _ = request_tx.send(signal); let _ = request_tx.send(signal);
} }
} }
_ = cancel_token_cloned.cancelled() => { _ = cancel_token_cloned.cancelled() => {
tracing::info!("Scheduler output task cancelled, clearing active requests"); tracing::info!("Scheduler output task cancelled, clearing active requests");
// Clear all active requests to unblock waiting request handlers active_requests_clone.clear();
// This will cause their request_rx.recv() to return None
let mut active = active_requests_clone.lock().await;
active.clear();
break; break;
} }
} }
...@@ -249,10 +213,11 @@ impl MockVllmEngine { ...@@ -249,10 +213,11 @@ impl MockVllmEngine {
}); });
} }
// Set the senders once // Set the senders once and notify waiters
self.request_senders self.request_senders
.set(senders) .set(senders)
.expect("Already initialized"); .expect("Already initialized");
self.senders_ready.notify_waiters();
schedulers schedulers
} }
...@@ -260,30 +225,14 @@ impl MockVllmEngine { ...@@ -260,30 +225,14 @@ impl MockVllmEngine {
/// Start background tasks to publish metrics on change /// Start background tasks to publish metrics on change
async fn start_metrics_publishing( async fn start_metrics_publishing(
schedulers: &[Scheduler], schedulers: &[Scheduler],
component: Option<Component>, component: Component,
cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> Result<()> { ) -> Result<()> {
tracing::debug!("Creating metrics publisher");
let metrics_publisher = Arc::new(WorkerMetricsPublisher::new()?); let metrics_publisher = Arc::new(WorkerMetricsPublisher::new()?);
tracing::debug!("Metrics publisher created");
if let Some(comp) = component {
tracing::debug!("Creating metrics endpoint");
tokio::spawn({
let publisher = metrics_publisher.clone();
async move {
if let Err(e) = publisher.create_endpoint(comp.clone()).await {
tracing::error!("Metrics endpoint failed: {e}");
}
}
});
// Give it a moment to start if let Err(e) = metrics_publisher.create_endpoint(component).await {
tokio::time::sleep(Duration::from_millis(100)).await; tracing::error!("Metrics endpoint failed: {e}");
tracing::debug!("Metrics endpoint started (background)");
} }
tracing::debug!("Starting metrics background tasks");
for scheduler in schedulers.iter() { for scheduler in schedulers.iter() {
let mut metrics_rx = scheduler.metrics_receiver(); let mut metrics_rx = scheduler.metrics_receiver();
let publisher = metrics_publisher.clone(); let publisher = metrics_publisher.clone();
...@@ -347,7 +296,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -347,7 +296,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
// - Prefill: complete_room() is called after first token (see below) // - Prefill: complete_room() is called after first token (see below)
let bootstrap_room = request.bootstrap_info.as_ref().map(|b| b.bootstrap_room); let bootstrap_room = request.bootstrap_info.as_ref().map(|b| b.bootstrap_room);
if let Some(bootstrap_info) = &request.bootstrap_info if let Some(bootstrap_info) = &request.bootstrap_info
&& self.engine_args.worker_type == WorkerType::Decode && self.engine_args.is_decode()
{ {
connect_to_prefill( connect_to_prefill(
&bootstrap_info.bootstrap_host, &bootstrap_info.bootstrap_host,
...@@ -360,15 +309,15 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -360,15 +309,15 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4()); let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4());
// For prefill workers, override max_tokens to 1 let is_prefill = self.engine_args.is_prefill();
let is_prefill = self.engine_args.worker_type == WorkerType::Prefill;
let max_output_tokens = if is_prefill { let max_output_tokens = if is_prefill {
1 1
} else { } else {
request request
.stop_conditions .stop_conditions
.max_tokens .max_tokens
.expect("max_output_tokens must be specified for mocker") as usize .ok_or_else(|| Error::msg("max_output_tokens must be specified for mocker"))?
as usize
}; };
// Convert PreprocessedRequest to DirectRequest for scheduler // Convert PreprocessedRequest to DirectRequest for scheduler
...@@ -380,10 +329,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -380,10 +329,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
}; };
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<OutputSignal>(); let (request_tx, mut request_rx) = mpsc::unbounded_channel::<OutputSignal>();
{ self.active_requests.insert(request_uuid, request_tx);
let mut active = self.active_requests.lock().await;
active.insert(request_uuid, request_tx);
}
// Send the request to the appropriate scheduler based on dp_rank // Send the request to the appropriate scheduler based on dp_rank
self.direct(direct_request, dp_rank as usize).await; self.direct(direct_request, dp_rank as usize).await;
...@@ -413,24 +359,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -413,24 +359,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let output = LLMEngineOutput { let output = LLMEngineOutput {
token_ids: vec![token_id], token_ids: vec![token_id],
tokens: None, // Let backend handle detokenization disaggregated_params: is_prefill.then(|| serde_json::json!("dummy")),
text: None, ..Default::default()
output_type: Default::default(),
content_parts: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: None,
stop_reason: None,
index: None,
// Add dummy disaggregated_params for prefill workers
disaggregated_params: if is_prefill {
Some(serde_json::json!("dummy"))
} else {
None
},
extra_args: None,
completion_usage: None,
}; };
// Prefill: after first token, mark room complete (unblocks decode) // Prefill: after first token, mark room complete (unblocks decode)
...@@ -465,12 +395,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -465,12 +395,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
} }
} }
// Clean up: remove this request from active requests active_requests.remove(&request_uuid);
let mut active = active_requests.lock().await;
active.remove(&request_uuid);
}); });
// Create a simple UnboundedReceiverStream which is naturally Send + Sync
let stream = UnboundedReceiverStream::new(stream_rx); let stream = UnboundedReceiverStream::new(stream_rx);
Ok(ResponseStream::new(Box::pin(stream), ctx.context())) Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
} }
...@@ -490,41 +417,33 @@ impl AnnotatedMockEngine { ...@@ -490,41 +417,33 @@ impl AnnotatedMockEngine {
let inner_clone = inner.clone(); let inner_clone = inner.clone();
// Start background task to wait for component service and start the engine // Start background task to wait for component service and start the engine
let cancel_token = distributed_runtime.primary_token();
tokio::spawn(async move { tokio::spawn(async move {
loop { let component = loop {
// Try to create component if cancel_token.is_cancelled() {
let Ok(namespace) = distributed_runtime.namespace(&endpoint_id.namespace) else { tracing::debug!("Mocker engine startup cancelled");
tracing::debug!("Namespace not available yet, retrying..."); return;
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
let Ok(component) = namespace.component(&endpoint_id.component) else {
tracing::debug!("Component not available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
// Check if service is available by trying to list instances
let Ok(instances) = component.list_instances().await else {
tracing::debug!("Cannot list instances yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
if instances.is_empty() {
tracing::debug!("No instances available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
} }
tracing::debug!("Component service is now available, starting mocker engine"); let ready = distributed_runtime
.namespace(&endpoint_id.namespace)
.and_then(|ns| ns.component(&endpoint_id.component))
.ok();
// Start the engine with the component if let Some(comp) = ready
if let Err(e) = inner_clone.start(component).await { && let Ok(instances) = comp.list_instances().await
tracing::error!("Failed to start mocker engine: {e}"); && !instances.is_empty()
{
break comp;
} }
break;
tracing::debug!("Component service not available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
};
tracing::debug!("Component service is now available, starting mocker engine");
if let Err(e) = inner_clone.start(component).await {
tracing::error!("Failed to start mocker engine: {e}");
} }
}); });
......
...@@ -161,6 +161,18 @@ impl MockEngineArgs { ...@@ -161,6 +161,18 @@ impl MockEngineArgs {
MockEngineArgsBuilder::default() MockEngineArgsBuilder::default()
} }
/// Reports whether these arguments configure a prefill worker,
/// i.e. `worker_type` is `WorkerType::Prefill`.
pub fn is_prefill(&self) -> bool {
    matches!(self.worker_type, WorkerType::Prefill)
}
/// Reports whether these arguments configure a decode worker,
/// i.e. `worker_type` is `WorkerType::Decode`.
pub fn is_decode(&self) -> bool {
    matches!(self.worker_type, WorkerType::Decode)
}
/// Reports whether a KV event publisher is needed: prefix caching must be
/// enabled and the worker must not be a decode worker.
pub fn needs_kv_publisher(&self) -> bool {
    // Without prefix caching no publisher is needed, whatever the worker type.
    if !self.enable_prefix_caching {
        return false;
    }
    !self.is_decode()
}
/// Create MockEngineArgs from a JSON file containing extra engine arguments /// Create MockEngineArgs from a JSON file containing extra engine arguments
pub fn from_json_file(path: &Path) -> anyhow::Result<Self> { pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
let mut builder = Self::builder(); let mut builder = Self::builder();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment