Unverified commit 5a3c52ab, authored by Yan Ru Pei and committed by GitHub
Browse files

chore: more mocker rusty cleanups (#6271)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent a582f4df
......@@ -6,20 +6,19 @@
//! The core mocker logic lives in the `dynamo-mocker` crate.
//! This module provides the runtime-dependent engine wrapper.
use std::collections::HashMap;
use std::sync::{Arc, LazyLock};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use dashmap::DashMap;
use futures::StreamExt;
use rand::Rng;
use tokio::sync::{Mutex, OnceCell, mpsc};
use tokio::sync::{Notify, OnceCell, mpsc};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::config::environment_names::mocker as env_mocker;
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::{
component::Component,
......@@ -35,7 +34,7 @@ use dynamo_kv_router::protocols::KvCacheEvent;
// Re-export from dynamo-mocker for convenience
use dynamo_mocker::bootstrap::{BootstrapServer, connect_to_prefill};
use dynamo_mocker::protocols::{OutputSignal, WorkerType};
use dynamo_mocker::protocols::OutputSignal;
pub use dynamo_mocker::{
DirectRequest, KvCacheEventSink, MockEngineArgs, MockEngineArgsBuilder, Scheduler, bootstrap,
evictor, kv_manager, perf_model, protocols, running_mean, scheduler, sequence,
......@@ -43,9 +42,6 @@ pub use dynamo_mocker::{
pub const MOCKER_COMPONENT: &str = "mocker";
static MOCKER_DIRECT_SYNC: LazyLock<bool> =
LazyLock::new(|| dynamo_runtime::config::env_is_truthy(env_mocker::DYN_MOCKER_SYNC_DIRECT));
/// Newtype wrapper around [`KvEventPublisher`] so that a publisher can be
/// handed out as an `Arc<dyn KvCacheEventSink>`.
struct KvEventSinkAdapter(KvEventPublisher);
......@@ -63,10 +59,10 @@ fn generate_random_token() -> TokenIdType {
}
/// AsyncEngine wrapper around the Scheduler that generates random character tokens
#[derive(Clone)]
pub struct MockVllmEngine {
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
request_senders: Arc<OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>>,
active_requests: Arc<DashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>,
request_senders: OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>,
senders_ready: Notify,
engine_args: MockEngineArgs,
/// Bootstrap server for prefill workers in disaggregated mode
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
......@@ -74,11 +70,12 @@ pub struct MockVllmEngine {
impl MockVllmEngine {
/// Create a new MockVllmEngine with the given parameters
pub fn new(args: MockEngineArgs) -> Self {
pub fn new(engine_args: MockEngineArgs) -> Self {
Self {
active_requests: Arc::new(Mutex::new(HashMap::new())),
request_senders: Arc::new(OnceCell::new()),
engine_args: args,
active_requests: Arc::new(DashMap::new()),
request_senders: OnceCell::new(),
senders_ready: Notify::new(),
engine_args,
bootstrap_server: Arc::new(OnceCell::new()),
}
}
......@@ -98,7 +95,7 @@ impl MockVllmEngine {
}
// Start bootstrap server for prefill workers in disaggregated mode
if self.engine_args.worker_type == WorkerType::Prefill
if self.engine_args.is_prefill()
&& let Some(port) = self.engine_args.bootstrap_port
{
let server = BootstrapServer::start(port, cancel_token.clone()).await?;
......@@ -106,88 +103,61 @@ impl MockVllmEngine {
tracing::info!(port = port, "Bootstrap server started for prefill worker");
}
// Determine if we need KV event publishers (prefix caching enabled and not decode worker)
let needs_kv_publisher = self.engine_args.enable_prefix_caching
&& self.engine_args.worker_type != WorkerType::Decode;
if needs_kv_publisher {
let kv_component = if self.engine_args.needs_kv_publisher() {
tracing::info!(
"Initializing KV event publisher with block_size {}, enable_local_indexer={}",
self.engine_args.block_size,
self.engine_args.enable_local_indexer
);
}
let schedulers = self.start_schedulers(
self.engine_args.clone(),
self.active_requests.clone(),
if needs_kv_publisher {
Some(component.clone())
Some(&component)
} else {
None
},
cancel_token.clone(),
);
};
Self::start_metrics_publishing(&schedulers, Some(component.clone()), cancel_token.clone())
.await?;
let schedulers = self.start_schedulers(kv_component, cancel_token.clone());
Self::start_metrics_publishing(&schedulers, component, cancel_token.clone()).await?;
Ok(())
}
/// Send a request to the appropriate scheduler.
///
/// Set `DYN_MOCKER_SYNC_DIRECT=1` to use the original direct path.
/// - `DYN_MOCKER_SYNC_DIRECT=1` (original, race-condition prone): 922/1000 pass
/// - `DYN_MOCKER_SYNC_DIRECT=0` (use timeout to wait for init): 1000/1000 pass
/// Send a request to the appropriate scheduler, waiting for initialization if needed.
pub async fn direct(&self, request: DirectRequest, dp_rank: usize) {
// `direct()` can be called before `start_schedulers()` finishes populating
// `request_senders` under load. The original path panics immediately; the
// default path waits briefly for initialization to complete.
if *MOCKER_DIRECT_SYNC {
let senders = self.request_senders.get().expect("Not initialized");
if let Some(senders) = self.request_senders.get() {
let _ = senders[dp_rank].send(request);
return;
}
// Poll request_senders until initialized (or time out) to avoid the startup race.
let start = std::time::Instant::now();
loop {
// Register the waiter *before* re-checking to avoid a TOCTOU race
// where `start_schedulers` sets + notifies between our check and subscribe.
let notified = self.senders_ready.notified();
if let Some(senders) = self.request_senders.get() {
let _ = senders[dp_rank].send(request);
return;
}
// We can parameterize the timeout to be more flexible.
// For example, on production this could be very short, but in a
// CPU-heavy test environment, this should be very high.
if start.elapsed() > Duration::from_secs(10) {
panic!("Scheduler initialization timed out after 10s");
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
notified.await;
let senders = self
.request_senders
.get()
.expect("must be set after notify");
let _ = senders[dp_rank].send(request);
}
/// Create schedulers and spawn their background tasks for distributing token notifications
fn start_schedulers(
&self,
args: MockEngineArgs,
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
component: Option<Component>,
component: Option<&Component>,
cancel_token: CancellationToken,
) -> Vec<Scheduler> {
let args = &self.engine_args;
let mut schedulers = Vec::<Scheduler>::new();
let mut senders = Vec::with_capacity(args.dp_size as usize);
// Create multiple schedulers and their background tasks
for dp_rank in 0..args.dp_size {
// Create a shared output channel that this scheduler will use
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
// Create a KvEventPublisher for THIS dp_rank if component is provided
let kv_event_sink: Option<Arc<dyn KvCacheEventSink>> =
component
.as_ref()
.and_then(|comp| {
let kv_event_sink: Option<Arc<dyn KvCacheEventSink>> = component.and_then(|comp| {
match KvEventPublisher::new_with_local_indexer(
comp.clone(),
args.block_size as u32,
......@@ -195,8 +165,9 @@ impl MockVllmEngine {
args.enable_local_indexer,
dp_rank,
) {
Ok(publisher) => Some(Arc::new(KvEventSinkAdapter(publisher))
as Arc<dyn KvCacheEventSink>),
Ok(publisher) => {
Some(Arc::new(KvEventSinkAdapter(publisher)) as Arc<dyn KvCacheEventSink>)
}
Err(e) => {
tracing::error!(
"Failed to create KV event publisher for dp_rank {dp_rank}: {e}"
......@@ -217,9 +188,7 @@ impl MockVllmEngine {
senders.push(scheduler.request_sender());
schedulers.push(scheduler);
// Spawn a background task for this scheduler to distribute token notifications to active requests
// let output_rx = Arc::new(Mutex::new(output_rx));
let active_requests_clone = active_requests.clone();
let active_requests_clone = self.active_requests.clone();
let cancel_token_cloned = cancel_token.clone();
tokio::spawn(async move {
......@@ -230,18 +199,13 @@ impl MockVllmEngine {
break; // Channel closed
};
// Notify the specific request that a token was generated
let active = active_requests_clone.lock().await;
if let Some(request_tx) = active.get(&signal.uuid) {
if let Some(request_tx) = active_requests_clone.get(&signal.uuid) {
let _ = request_tx.send(signal);
}
}
_ = cancel_token_cloned.cancelled() => {
tracing::info!("Scheduler output task cancelled, clearing active requests");
// Clear all active requests to unblock waiting request handlers
// This will cause their request_rx.recv() to return None
let mut active = active_requests_clone.lock().await;
active.clear();
active_requests_clone.clear();
break;
}
}
......@@ -249,10 +213,11 @@ impl MockVllmEngine {
});
}
// Set the senders once
// Set the senders once and notify waiters
self.request_senders
.set(senders)
.expect("Already initialized");
self.senders_ready.notify_waiters();
schedulers
}
......@@ -260,30 +225,14 @@ impl MockVllmEngine {
/// Start background tasks to publish metrics on change
async fn start_metrics_publishing(
schedulers: &[Scheduler],
component: Option<Component>,
component: Component,
cancel_token: CancellationToken,
) -> Result<()> {
tracing::debug!("Creating metrics publisher");
let metrics_publisher = Arc::new(WorkerMetricsPublisher::new()?);
tracing::debug!("Metrics publisher created");
if let Some(comp) = component {
tracing::debug!("Creating metrics endpoint");
tokio::spawn({
let publisher = metrics_publisher.clone();
async move {
if let Err(e) = publisher.create_endpoint(comp.clone()).await {
if let Err(e) = metrics_publisher.create_endpoint(component).await {
tracing::error!("Metrics endpoint failed: {e}");
}
}
});
// Give it a moment to start
tokio::time::sleep(Duration::from_millis(100)).await;
tracing::debug!("Metrics endpoint started (background)");
}
tracing::debug!("Starting metrics background tasks");
for scheduler in schedulers.iter() {
let mut metrics_rx = scheduler.metrics_receiver();
let publisher = metrics_publisher.clone();
......@@ -347,7 +296,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
// - Prefill: complete_room() is called after first token (see below)
let bootstrap_room = request.bootstrap_info.as_ref().map(|b| b.bootstrap_room);
if let Some(bootstrap_info) = &request.bootstrap_info
&& self.engine_args.worker_type == WorkerType::Decode
&& self.engine_args.is_decode()
{
connect_to_prefill(
&bootstrap_info.bootstrap_host,
......@@ -360,15 +309,15 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4());
// For prefill workers, override max_tokens to 1
let is_prefill = self.engine_args.worker_type == WorkerType::Prefill;
let is_prefill = self.engine_args.is_prefill();
let max_output_tokens = if is_prefill {
1
} else {
request
.stop_conditions
.max_tokens
.expect("max_output_tokens must be specified for mocker") as usize
.ok_or_else(|| Error::msg("max_output_tokens must be specified for mocker"))?
as usize
};
// Convert PreprocessedRequest to DirectRequest for scheduler
......@@ -380,10 +329,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
};
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<OutputSignal>();
{
let mut active = self.active_requests.lock().await;
active.insert(request_uuid, request_tx);
}
self.active_requests.insert(request_uuid, request_tx);
// Send the request to the appropriate scheduler based on dp_rank
self.direct(direct_request, dp_rank as usize).await;
......@@ -413,24 +359,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let output = LLMEngineOutput {
token_ids: vec![token_id],
tokens: None, // Let backend handle detokenization
text: None,
output_type: Default::default(),
content_parts: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: None,
stop_reason: None,
index: None,
// Add dummy disaggregated_params for prefill workers
disaggregated_params: if is_prefill {
Some(serde_json::json!("dummy"))
} else {
None
},
extra_args: None,
completion_usage: None,
disaggregated_params: is_prefill.then(|| serde_json::json!("dummy")),
..Default::default()
};
// Prefill: after first token, mark room complete (unblocks decode)
......@@ -465,12 +395,9 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
}
}
// Clean up: remove this request from active requests
let mut active = active_requests.lock().await;
active.remove(&request_uuid);
active_requests.remove(&request_uuid);
});
// Create a simple UnboundedReceiverStream which is naturally Send + Sync
let stream = UnboundedReceiverStream::new(stream_rx);
Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
}
......@@ -490,42 +417,34 @@ impl AnnotatedMockEngine {
let inner_clone = inner.clone();
// Start background task to wait for component service and start the engine
let cancel_token = distributed_runtime.primary_token();
tokio::spawn(async move {
loop {
// Try to create component
let Ok(namespace) = distributed_runtime.namespace(&endpoint_id.namespace) else {
tracing::debug!("Namespace not available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
let component = loop {
if cancel_token.is_cancelled() {
tracing::debug!("Mocker engine startup cancelled");
return;
}
let Ok(component) = namespace.component(&endpoint_id.component) else {
tracing::debug!("Component not available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
let ready = distributed_runtime
.namespace(&endpoint_id.namespace)
.and_then(|ns| ns.component(&endpoint_id.component))
.ok();
// Check if service is available by trying to list instances
let Ok(instances) = component.list_instances().await else {
tracing::debug!("Cannot list instances yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
if let Some(comp) = ready
&& let Ok(instances) = comp.list_instances().await
&& !instances.is_empty()
{
break comp;
}
if instances.is_empty() {
tracing::debug!("No instances available yet, retrying...");
tracing::debug!("Component service not available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
}
};
tracing::debug!("Component service is now available, starting mocker engine");
// Start the engine with the component
if let Err(e) = inner_clone.start(component).await {
tracing::error!("Failed to start mocker engine: {e}");
}
break;
}
});
Self { inner }
......
......@@ -161,6 +161,18 @@ impl MockEngineArgs {
MockEngineArgsBuilder::default()
}
/// Returns `true` when `worker_type` is [`WorkerType::Prefill`].
pub fn is_prefill(&self) -> bool {
    matches!(self.worker_type, WorkerType::Prefill)
}
/// Returns `true` when `worker_type` is [`WorkerType::Decode`].
pub fn is_decode(&self) -> bool {
    matches!(self.worker_type, WorkerType::Decode)
}
/// Reports whether a KV event publisher is needed: prefix caching must be
/// enabled and the worker must not be a decode worker.
pub fn needs_kv_publisher(&self) -> bool {
    if !self.enable_prefix_caching {
        return false;
    }
    !self.is_decode()
}
/// Create MockEngineArgs from a JSON file containing extra engine arguments
pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
let mut builder = Self::builder();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment