Commit 1c39e483 (unverified), authored by Sean SH Choi, committed by GitHub
Browse files

feat: add frontend based prefill request routing for sglang (#4635)


Signed-off-by: Sean Choi <sechoi@nvidia.com>
Signed-off-by: PeaBrane <yanrpei@gmail.com>
Co-authored-by: Yan Ru Pei <yanrpei@gmail.com>
Co-authored-by: ishandhanani <82981111+ishandhanani@users.noreply.github.com>
parent 7a5f70f3
...@@ -137,22 +137,6 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -137,22 +137,6 @@ async def init(runtime: DistributedRuntime, config: Config):
"Registered engine routes: /engine/start_profile, /engine/stop_profile" "Registered engine routes: /engine/start_profile, /engine/stop_profile"
) )
prefill_client = None
prefill_router_client = None
if config.serving_mode == DisaggregationMode.DECODE:
prefill_router_client = (
await runtime.namespace(dynamo_args.namespace)
.component("router")
.endpoint("best_worker_id")
.client()
)
prefill_client = (
await runtime.namespace(dynamo_args.namespace)
.component("prefill")
.endpoint("generate")
.client()
)
# publisher instantiates the metrics and kv event publishers # publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint engine, config, component, generate_endpoint
...@@ -165,9 +149,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -165,9 +149,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# Readiness gate: requests wait until model is registered # Readiness gate: requests wait until model is registered
ready_event = asyncio.Event() ready_event = asyncio.Event()
handler = DecodeWorkerHandler( handler = DecodeWorkerHandler(component, engine, config, publisher)
component, engine, config, publisher, prefill_client, prefill_router_client
)
print(f"Config: {config}") print(f"Config: {config}")
health_check_payload = SglangHealthCheckPayload( health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer engine, use_text_input=dynamo_args.use_sglang_tokenizer
...@@ -272,17 +254,29 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -272,17 +254,29 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
tasks = [ # Readiness gate: requests wait until model is registered
generate_endpoint.serve_endpoint( ready_event = asyncio.Event()
handler.generate,
graceful_shutdown=True,
metrics_labels=metrics_labels,
health_check_payload=health_check_payload,
)
]
try: try:
await asyncio.gather(*tasks) # Start endpoint immediately and register model concurrently
# Registration publishes runtime_config with bootstrap endpoint for optimization
await asyncio.gather(
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=metrics_labels,
health_check_payload=health_check_payload,
),
register_llm_with_readiness_gate(
engine,
generate_endpoint,
server_args,
dynamo_args,
input_type=ModelInput.Tokens,
output_type=ModelType.Prefill,
readiness_gate=ready_event,
),
)
except Exception as e: except Exception as e:
logging.error(f"Failed to serve endpoints: {e}") logging.error(f"Failed to serve endpoints: {e}")
raise raise
......
...@@ -3,10 +3,12 @@ ...@@ -3,10 +3,12 @@
import asyncio import asyncio
import logging import logging
import socket
from typing import Optional from typing import Optional
import sglang as sgl import sglang as sgl
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Endpoint from dynamo._core import Endpoint
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
...@@ -65,6 +67,39 @@ async def _register_llm_with_runtime_config( ...@@ -65,6 +67,39 @@ async def _register_llm_with_runtime_config(
return False return False
def _get_bootstrap_info_for_config(
    engine: sgl.Engine,
) -> tuple[Optional[str], Optional[int]]:
    """Work out the bootstrap host and port to publish for this engine.

    The port is read from the ``disaggregation_bootstrap_port`` attribute of the
    tokenizer manager's server args. The host is the resolved host part of
    ``dist_init_addr`` when that is set, otherwise the automatically detected
    local IP.

    Args:
        engine: The SGLang engine instance.

    Returns:
        ``(bootstrap_host, bootstrap_port)``, or ``(None, None)`` when no
        bootstrap port is configured or any step of the lookup raises.
    """
    try:
        args = engine.tokenizer_manager.server_args
        port = getattr(args, "disaggregation_bootstrap_port", None)
        if port is None:
            return None, None

        init_addr = args.dist_init_addr
        if not init_addr:
            return get_local_ip_auto(), port

        # dist_init_addr is "host:port"; only the host part is resolved.
        host_part = init_addr.split(":")[0]
        return socket.gethostbyname(host_part), port
    except Exception as e:
        # Best effort: a failed lookup is logged and reported as "not available".
        logging.warning(f"Failed to get bootstrap info: {e}")
        return None, None
async def _get_runtime_config( async def _get_runtime_config(
engine: sgl.Engine, server_args: ServerArgs, dynamo_args: DynamoArgs engine: sgl.Engine, server_args: ServerArgs, dynamo_args: DynamoArgs
) -> Optional[ModelRuntimeConfig]: ) -> Optional[ModelRuntimeConfig]:
...@@ -84,6 +119,14 @@ async def _get_runtime_config( ...@@ -84,6 +119,14 @@ async def _get_runtime_config(
runtime_config.tool_call_parser = dynamo_args.tool_call_parser runtime_config.tool_call_parser = dynamo_args.tool_call_parser
runtime_config.enable_local_indexer = dynamo_args.enable_local_indexer runtime_config.enable_local_indexer = dynamo_args.enable_local_indexer
# Set bootstrap endpoint for disaggregated serving (prefill workers)
bootstrap_host, bootstrap_port = _get_bootstrap_info_for_config(engine)
if bootstrap_host and bootstrap_port:
runtime_config.set_disaggregated_endpoint(bootstrap_host, bootstrap_port)
logging.info(
f"Publishing disaggregated endpoint to discovery: "
f"{bootstrap_host}:{bootstrap_port}"
)
# In SGLang, these are server_args, not scheduler_info (unlike vLLM) # In SGLang, these are server_args, not scheduler_info (unlike vLLM)
# Note: If --max-running-requests is not specified, SGLang uses an internal default # Note: If --max-running-requests is not specified, SGLang uses an internal default
# undocumented value. The value here will be None if not explicitly set by user. # undocumented value. The value here will be None if not explicitly set by user.
......
...@@ -4,16 +4,18 @@ ...@@ -4,16 +4,18 @@
import asyncio import asyncio
import logging import logging
import time import time
from typing import Any, AsyncGenerator, Dict, Optional from typing import Any, AsyncGenerator, Dict
import sglang as sgl import sglang as sgl
from dynamo._core import Client, Component, Context from dynamo._core import Component, Context
from dynamo.sglang.args import Config, DisaggregationMode from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.protocol import DisaggPreprocessedRequest
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
# Timeout for decode engine to receive first response when waiting for KV cache transfer
DECODE_KV_TRANSFER_TIMEOUT_SECONDS = 60.0
class DecodeWorkerHandler(BaseWorkerHandler): class DecodeWorkerHandler(BaseWorkerHandler):
"""Handler for decode workers in both aggregated and disaggregated serving modes.""" """Handler for decode workers in both aggregated and disaggregated serving modes."""
...@@ -24,8 +26,6 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -24,8 +26,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
publisher: DynamoSglangPublisher, publisher: DynamoSglangPublisher,
prefill_client: Optional[Client] = None,
prefill_router_client: Optional[Client] = None,
) -> None: ) -> None:
"""Initialize decode worker handler. """Initialize decode worker handler.
...@@ -34,29 +34,19 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -34,29 +34,19 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: The SGLang engine instance. engine: The SGLang engine instance.
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker. publisher: Metrics publisher for the worker.
prefill_client: Optional client for prefill worker in disaggregated mode.
prefill_router_client: Optional client for prefill router in disaggregated mode.
Raises:
ValueError: If prefill_client is not provided in decode serving mode.
""" """
super().__init__( super().__init__(
component, component,
engine, engine,
config, config,
publisher, publisher,
prefill_client,
) )
if self.serving_mode == DisaggregationMode.DECODE: if self.serving_mode == DisaggregationMode.DECODE:
if self.prefill_client is None: logging.info(
raise ValueError( "Decode worker handler initialized (disaggregated decode mode)"
"prefill_client must be provided when serving_mode is decode" )
) else:
self.prefill_client = prefill_client logging.info("Decode worker handler initialized (aggregated mode)")
logging.info("Decode worker handler initialized")
self.prefill_router_client = prefill_router_client
logging.info("Worker handler initialized")
def cleanup(self) -> None: def cleanup(self) -> None:
"""Shutdown the engine and cleanup resources.""" """Shutdown the engine and cleanup resources."""
...@@ -117,43 +107,20 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -117,43 +107,20 @@ class DecodeWorkerHandler(BaseWorkerHandler):
input_param = self._get_input_param(request) input_param = self._get_input_param(request)
if self.serving_mode == DisaggregationMode.DECODE: if self.serving_mode == DisaggregationMode.DECODE:
# request the bootstrap info from the target prefill worker # Check if bootstrap_info is in the request
if ( bootstrap_info = request.get("bootstrap_info")
self.prefill_router_client is not None
and self.prefill_router_client.instance_ids()
):
token_ids = request["token_ids"]
stream = await self.prefill_router_client.generate(token_ids)
result = await anext(stream)
(
worker_id,
overlap,
) = result.data() # Returns tuple (worker_id, overlap_amount)
logging.info(f"Best prefill worker ID: {worker_id}, overlap: {overlap}")
prefill_stream = await self.prefill_client.direct(
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump(),
worker_id,
)
else:
prefill_stream = await self.prefill_client.generate(
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump(),
context=context,
)
bootstrap_info = None
async for info in prefill_stream:
bootstrap_info = info.data()
break
if not bootstrap_info: if not bootstrap_info:
raise RuntimeError("No bootstrap info received from prefill worker") raise RuntimeError(
"bootstrap_info is required for disaggregated decode but was not provided."
)
logging.debug(
f"Using bootstrap_info: "
f"host={bootstrap_info['bootstrap_host']}, "
f"port={bootstrap_info['bootstrap_port']}, "
f"room={bootstrap_info['bootstrap_room']}"
)
if self.enable_trace: if self.enable_trace:
self._propagate_trace_context_to_sglang( self._propagate_trace_context_to_sglang(
...@@ -170,11 +137,28 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -170,11 +137,28 @@ class DecodeWorkerHandler(BaseWorkerHandler):
rid=trace_id, rid=trace_id,
) )
# Wait for first token with timeout
decode_iter = decode.__aiter__()
try:
first_res = await asyncio.wait_for(
decode_iter.__anext__(), timeout=DECODE_KV_TRANSFER_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
raise RuntimeError(
f"Decode timed out after {DECODE_KV_TRANSFER_TIMEOUT_SECONDS}s waiting for first token. "
)
# Create stream starting with first result
async def decode_stream() -> AsyncGenerator[Dict[str, Any], None]:
yield first_res
async for res in decode_iter:
yield res
if self.skip_tokenizer_init: if self.skip_tokenizer_init:
async for out in self._process_token_stream(decode, context): async for out in self._process_token_stream(decode_stream(), context):
yield out yield out
else: else:
async for out in self._process_text_stream(decode, context): async for out in self._process_text_stream(decode_stream(), context):
yield out yield out
else: else:
if self.enable_trace: if self.enable_trace:
......
...@@ -57,7 +57,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -57,7 +57,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
"""Generate prefill output and provide bootstrap info for decode worker. """Generate prefill output and provide bootstrap info for decode worker.
Args: Args:
request: Request dict with 'request' and 'sampling_params' keys. request: Request dict with 'request', 'sampling_params', and possibly 'bootstrap_room' keys.
context: Context object for cancellation handling. context: Context object for cancellation handling.
Yields: Yields:
...@@ -65,7 +65,35 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -65,7 +65,35 @@ class PrefillWorkerHandler(BaseWorkerHandler):
""" """
logging.debug(f"New Request ID: {context.id()}") logging.debug(f"New Request ID: {context.id()}")
trace_id = context.trace_id trace_id = context.trace_id
bootstrap_room = self._generate_bootstrap_room()
if "request" in request:
# DisaggPreprocessedRequest format
inner_request = request["request"]
sampling_params = request.get("sampling_params", {})
else:
inner_request = request
sampling_opts = request.get("sampling_options", {})
stop_conditions = request.get("stop_conditions", {})
sampling_params = {
"temperature": sampling_opts.get("temperature"),
"top_p": sampling_opts.get("top_p"),
"top_k": sampling_opts.get("top_k"),
"max_new_tokens": stop_conditions.get("max_tokens"),
}
sampling_params = {
k: v for k, v in sampling_params.items() if v is not None
}
# Use provided bootstrap_room if available, otherwise generate one
bootstrap_room = None
extra_args = inner_request.get("extra_args", {})
if isinstance(extra_args, dict):
bootstrap_room = extra_args.get("bootstrap_room")
logging.debug(f"Using router-provided bootstrap_room: {bootstrap_room}")
if bootstrap_room is None:
bootstrap_room = self._generate_bootstrap_room()
logging.debug(f"Generated bootstrap_room locally: {bootstrap_room}")
bootstrap_info = { bootstrap_info = {
"bootstrap_host": self.bootstrap_host, "bootstrap_host": self.bootstrap_host,
...@@ -73,9 +101,16 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -73,9 +101,16 @@ class PrefillWorkerHandler(BaseWorkerHandler):
"bootstrap_room": bootstrap_room, "bootstrap_room": bootstrap_room,
} }
yield bootstrap_info # Yield in LLMEngineOutput format for PrefillRouter compatibility
# The disaggregated_params field contains the bootstrap info
yield {
"token_ids": [],
"text": None,
"finish_reason": None,
"disaggregated_params": bootstrap_info,
}
input_param = self._get_input_param(request["request"]) input_param = self._get_input_param(inner_request)
# Propagate trace context to SGLang # Propagate trace context to SGLang
if self.enable_trace: if self.enable_trace:
...@@ -83,7 +118,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -83,7 +118,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
results = await self.engine.async_generate( results = await self.engine.async_generate(
**input_param, **input_param,
sampling_params=request["sampling_params"], sampling_params=sampling_params,
stream=True, stream=True,
bootstrap_host=self.bootstrap_host, bootstrap_host=self.bootstrap_host,
bootstrap_port=self.bootstrap_port, bootstrap_port=self.bootstrap_port,
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
# Setup cleanup trap # Setup cleanup trap
cleanup() { cleanup() {
echo "Cleaning up background processes..." echo "Cleaning up background processes..."
kill $DYNAMO_PID $PREFILL_PID $PREFILL_ROUTER_PID 2>/dev/null || true kill $DYNAMO_PID $PREFILL_PID1 $PREFILL_PID2 $DECODE_PID1 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID $PREFILL_ROUTER_PID 2>/dev/null || true wait $DYNAMO_PID $PREFILL_PID1 $PREFILL_PID2 $DECODE_PID1 2>/dev/null || true
echo "Cleanup complete." echo "Cleanup complete."
} }
trap cleanup EXIT INT TERM trap cleanup EXIT INT TERM
...@@ -26,7 +26,7 @@ while [[ $# -gt 0 ]]; do ...@@ -26,7 +26,7 @@ while [[ $# -gt 0 ]]; do
echo " -h, --help Show this help message" echo " -h, --help Show this help message"
echo "" echo ""
echo "Note: System metrics are enabled by default on ports:" echo "Note: System metrics are enabled by default on ports:"
echo " 8081 (router), 8082-8083 (prefill workers), 8084-8085 (decode workers)" echo " 8082-8083 (prefill workers), 8084-8085 (decode workers)"
exit 0 exit 0
;; ;;
*) *)
...@@ -46,13 +46,13 @@ if [ "$ENABLE_OTEL" = true ]; then ...@@ -46,13 +46,13 @@ if [ "$ENABLE_OTEL" = true ]; then
TRACE_ARGS+=(--enable-trace --otlp-traces-endpoint localhost:4317) TRACE_ARGS+=(--enable-trace --otlp-traces-endpoint localhost:4317)
fi fi
# run ingress # Start frontend with KV routing
# The frontend will automatically detect prefill workers and activate an internal prefill router
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000) # dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
OTEL_SERVICE_NAME=dynamo-frontend \ OTEL_SERVICE_NAME=dynamo-frontend \
python3 -m dynamo.frontend \ python3 -m dynamo.frontend \
--router-mode kv \ --router-mode kv \
--kv-overlap-score-weight 0 \ --router-reset-states &
--router-reset-states &
DYNAMO_PID=$! DYNAMO_PID=$!
# run prefill router # run prefill router
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::*; use super::*;
use llm_rs::local_model::runtime_config::DisaggregatedEndpoint as RsDisaggregatedEndpoint;
use llm_rs::local_model::runtime_config::ModelRuntimeConfig as RsModelRuntimeConfig; use llm_rs::local_model::runtime_config::ModelRuntimeConfig as RsModelRuntimeConfig;
#[pyclass] #[pyclass]
...@@ -125,4 +126,32 @@ impl ModelRuntimeConfig { ...@@ -125,4 +126,32 @@ impl ModelRuntimeConfig {
fn get_engine_specific(&self, key: &str) -> PyResult<Option<String>> { fn get_engine_specific(&self, key: &str) -> PyResult<Option<String>> {
self.inner.get_engine_specific(key).map_err(to_pyerr) self.inner.get_engine_specific(key).map_err(to_pyerr)
} }
/// Record the bootstrap endpoint for disaggregated serving on this config.
///
/// Both arguments are optional and default to `None`. The endpoint entry is
/// always stored, even when both arguments are `None`.
#[pyo3(signature = (bootstrap_host=None, bootstrap_port=None))]
fn set_disaggregated_endpoint(
    &mut self,
    bootstrap_host: Option<String>,
    bootstrap_port: Option<u16>,
) {
    let endpoint = RsDisaggregatedEndpoint {
        bootstrap_host,
        bootstrap_port,
    };
    self.inner.disaggregated_endpoint = Some(endpoint);
}
/// Bootstrap host of the stored disaggregated endpoint.
///
/// Returns `None` when no endpoint is stored or the endpoint has no host.
#[getter]
fn bootstrap_host(&self) -> Option<String> {
    let endpoint = self.inner.disaggregated_endpoint.as_ref()?;
    endpoint.bootstrap_host.clone()
}
/// Bootstrap port of the stored disaggregated endpoint.
///
/// Returns `None` when no endpoint is stored or the endpoint has no port.
#[getter]
fn bootstrap_port(&self) -> Option<u16> {
    let endpoint = self.inner.disaggregated_endpoint.as_ref()?;
    endpoint.bootstrap_port
}
} }
...@@ -562,6 +562,15 @@ impl KvRouter { ...@@ -562,6 +562,15 @@ impl KvRouter {
self.block_size self.block_size
} }
/// Look up the disaggregated endpoint for `worker_id`, if one is available.
///
/// This is a thin pass-through to the scheduler's `get_disaggregated_endpoint`;
/// whatever it returns (including `None`) is returned unchanged.
/// Used to look up bootstrap host/port for prefill workers.
pub async fn get_disaggregated_endpoint(
&self,
worker_id: u64,
) -> Option<crate::local_model::runtime_config::DisaggregatedEndpoint> {
self.scheduler.get_disaggregated_endpoint(worker_id).await
}
/// Get potential prefill and decode loads for all workers /// Get potential prefill and decode loads for all workers
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> { pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
......
...@@ -5,6 +5,7 @@ use std::sync::{Arc, OnceLock}; ...@@ -5,6 +5,7 @@ use std::sync::{Arc, OnceLock};
use anyhow::Result; use anyhow::Result;
use futures::StreamExt; use futures::StreamExt;
use rand::Rng;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -21,7 +22,7 @@ use crate::{ ...@@ -21,7 +22,7 @@ use crate::{
discovery::ModelManager, discovery::ModelManager,
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride}, kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::PrefillResult, protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
}; };
/// Errors that can occur during prefill routing /// Errors that can occur during prefill routing
...@@ -42,6 +43,7 @@ pub enum PrefillError { ...@@ -42,6 +43,7 @@ pub enum PrefillError {
} }
/// The inner router used by PrefillRouter /// The inner router used by PrefillRouter
#[derive(Clone)]
enum InnerPrefillRouter { enum InnerPrefillRouter {
/// KV-aware routing using KvPushRouter /// KV-aware routing using KvPushRouter
KvRouter(Arc<KvPushRouter>), KvRouter(Arc<KvPushRouter>),
...@@ -49,6 +51,19 @@ enum InnerPrefillRouter { ...@@ -49,6 +51,19 @@ enum InnerPrefillRouter {
SimpleRouter(Arc<PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>>), SimpleRouter(Arc<PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>>),
} }
impl InnerPrefillRouter {
    /// Forward a prefill request to whichever router variant is held.
    ///
    /// Both variants expose the same `generate` call; this method only picks
    /// the variant and returns its response stream (or its error) unchanged.
    async fn generate(
        &self,
        request: SingleIn<PreprocessedRequest>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
        match self {
            Self::KvRouter(kv) => kv.generate(request).await,
            Self::SimpleRouter(simple) => simple.generate(request).await,
        }
    }
}
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router. /// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params /// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request. /// from the prefill response and injecting them into the decode request.
...@@ -176,28 +191,75 @@ impl PrefillRouter { ...@@ -176,28 +191,75 @@ impl PrefillRouter {
Ok(()) Ok(())
} }
/// Call the prefill router and extract structured prefill result /// Generate a unique bootstrap room ID for disaggregated serving
async fn call_prefill( fn generate_bootstrap_room() -> u64 {
rand::rng().random()
}
/// Query best worker upfront, build bootstrap_info, and spawn prefill in background
async fn build_bootstrap_info(
&self, &self,
request: SingleIn<PreprocessedRequest>, req: &PreprocessedRequest,
) -> Result<PrefillResult, PrefillError> { ) -> Option<(u64, u32, BootstrapInfo)> {
// Get the prefill router, error if not activated let prefill_router = self.prefill_router.get()?;
let Some(prefill_router) = self.prefill_router.get() else {
return Err(PrefillError::NotActivated); // Only works with KvRouter
let kv_router = match prefill_router {
InnerPrefillRouter::KvRouter(r) => r,
InnerPrefillRouter::SimpleRouter(_) => return None,
}; };
// Call the appropriate router based on the type // Query best worker without routing
let mut prefill_response = match prefill_router { let (worker_id, dp_rank) = match kv_router
InnerPrefillRouter::KvRouter(router) => router .chooser
.generate(request) .find_best_match(None, &req.token_ids, None, false)
.await .await
.map_err(|e| PrefillError::PrefillError(e.to_string()))?, {
InnerPrefillRouter::SimpleRouter(router) => router Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank),
.generate(request) Err(_) => return None,
.await
.map_err(|e| PrefillError::PrefillError(e.to_string()))?,
}; };
// Look up bootstrap endpoint from discovery
let endpoint = kv_router
.chooser
.get_disaggregated_endpoint(worker_id)
.await?;
let host = endpoint.bootstrap_host?;
let port = endpoint.bootstrap_port?;
let bootstrap_room = Self::generate_bootstrap_room();
tracing::info!(
worker_id = worker_id,
dp_rank = dp_rank,
bootstrap_host = %host,
bootstrap_port = port,
bootstrap_room = bootstrap_room,
"Built bootstrap_info upfront before prefill"
);
Some((
worker_id,
dp_rank,
BootstrapInfo {
bootstrap_host: host,
bootstrap_port: port,
bootstrap_room,
},
))
}
/// Execute prefill with the given router and extract structured result
async fn execute_prefill(
router: Option<InnerPrefillRouter>,
request: SingleIn<PreprocessedRequest>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> {
let router = router.ok_or(PrefillError::NotActivated)?;
let mut prefill_response = router
.generate(request)
.await
.map_err(|e| PrefillError::PrefillError(e.to_string()))?;
let Some(first_output) = prefill_response.next().await else { let Some(first_output) = prefill_response.next().await else {
return Err(PrefillError::PrefillError( return Err(PrefillError::PrefillError(
"Prefill router returned no output (stream ended)".to_string(), "Prefill router returned no output (stream ended)".to_string(),
...@@ -239,10 +301,45 @@ impl PrefillRouter { ...@@ -239,10 +301,45 @@ impl PrefillRouter {
)); ));
}; };
Ok(PrefillResult { // Extract prefill worker ID from disaggregated_params
disaggregated_params, let prefill_worker_id = disaggregated_params
prompt_tokens_details, .get("worker_id")
}) .and_then(|worker_id_json| {
worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64())
});
Ok((
PrefillResult {
disaggregated_params,
prompt_tokens_details,
},
prefill_worker_id,
))
}
/// Run prefill in a spawned background task instead of awaiting it.
///
/// The currently activated router (if any) is cloned into the task. The task's
/// outcome is only logged: `debug` on success, `warn` on failure. Nothing is
/// returned to the caller and the join handle is not kept.
fn spawn_prefill_task(&self, prefill_request: SingleIn<PreprocessedRequest>) {
    let router = self.prefill_router.get().cloned();
    tokio::spawn(async move {
        if let Err(e) = Self::execute_prefill(router, prefill_request).await {
            tracing::warn!("Prefill background task error: {e:?}");
        } else {
            tracing::debug!("Prefill background task completed");
        }
    });
}
/// Call the prefill router and extract structured prefill result and worker ID
async fn call_prefill(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> {
Self::execute_prefill(self.prefill_router.get().cloned(), request).await
} }
} }
...@@ -278,15 +375,43 @@ impl ...@@ -278,15 +375,43 @@ impl
// Prepare prefill request with max_tokens = 1 // Prepare prefill request with max_tokens = 1
let mut prefill_req = req.clone(); let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1); prefill_req.stop_conditions.max_tokens = Some(1);
let prefill_context = Context::with_id(prefill_req, request_id.clone());
// Link the prefill context as a child so that kill signals propagate // Try build_bootstrap_info optimization
engine_ctx.link_child(prefill_context.context()); let prefill_result = if let Some((worker_id, dp_rank, bootstrap_info)) =
self.build_bootstrap_info(&prefill_req).await
{
let bootstrap_room = bootstrap_info.bootstrap_room;
// Prepare request with bootstrap_room and force routing to specific worker
prefill_req.backend_instance_id = Some(worker_id);
prefill_req.dp_rank = Some(dp_rank);
let extra_args = prefill_req
.extra_args
.get_or_insert_with(|| serde_json::json!({}));
if let Some(obj) = extra_args.as_object_mut() {
obj.insert(
"bootstrap_room".to_string(),
serde_json::json!(bootstrap_room),
);
}
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.spawn_prefill_task(prefill_context);
let prefill_request = prefill_context; Ok((None, Some(worker_id), Some(bootstrap_info)))
} else {
// Fallback to original: Wait for prefill to complete
tracing::debug!("Using original prefill path");
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
// Attempt prefill self.call_prefill(prefill_context)
let prefill_result = self.call_prefill(prefill_request).await; .await
.map(|(result, worker_id)| (Some(result), worker_id, None))
};
// Abort if cancelled during prefill // Abort if cancelled during prefill
if engine_ctx.is_stopped() || engine_ctx.is_killed() { if engine_ctx.is_stopped() || engine_ctx.is_killed() {
...@@ -299,15 +424,24 @@ impl ...@@ -299,15 +424,24 @@ impl
// Handle prefill result // Handle prefill result
match prefill_result { match prefill_result {
Ok(prefill_result) => { Ok((maybe_prefill_result, _prefill_worker_id, bootstrap_info)) => {
tracing::debug!("Prefill succeeded, using disaggregated params for decode"); tracing::debug!("Prefill completed, proceeding to decode");
let mut decode_req = req; let mut decode_req = req;
// Update request with prefill result
decode_req.prefill_result = Some(prefill_result.clone()); // Update request with prefill result if available (only in original path)
if let Some(prefill_result) = maybe_prefill_result {
decode_req.prefill_result = Some(prefill_result);
}
// Restore original max_tokens for decode // Restore original max_tokens for decode
decode_req.stop_conditions.max_tokens = original_max_tokens; decode_req.stop_conditions.max_tokens = original_max_tokens;
// Inject bootstrap_info for decode worker
if let Some(info) = bootstrap_info {
decode_req.bootstrap_info = Some(info);
}
// Set router_config_override for decode: overlap_score_weight = 0 // Set router_config_override for decode: overlap_score_weight = 0
let existing_override = decode_req.router_config_override.take(); let existing_override = decode_req.router_config_override.take();
decode_req.router_config_override = Some(RouterConfigOverride { decode_req.router_config_override = Some(RouterConfigOverride {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::{DisaggregatedEndpoint, ModelRuntimeConfig};
use anyhow::Result; use anyhow::Result;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
...@@ -90,6 +90,8 @@ impl SchedulingRequest { ...@@ -90,6 +90,8 @@ impl SchedulingRequest {
pub struct KvScheduler { pub struct KvScheduler {
request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>, request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMultiWorker>, slots: Arc<ActiveSequencesMultiWorker>,
/// Worker runtime configs for looking up disaggregated endpoints
workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>>,
} }
impl KvScheduler { impl KvScheduler {
...@@ -287,7 +289,11 @@ impl KvScheduler { ...@@ -287,7 +289,11 @@ impl KvScheduler {
tracing::trace!("background endpoint subscriber shutting down"); tracing::trace!("background endpoint subscriber shutting down");
}); });
Ok(KvScheduler { request_tx, slots }) Ok(KvScheduler {
request_tx,
slots,
workers_with_configs,
})
} }
pub async fn schedule( pub async fn schedule(
...@@ -346,6 +352,17 @@ impl KvScheduler { ...@@ -346,6 +352,17 @@ impl KvScheduler {
self.slots.free(&request_id.to_string()).await self.slots.free(&request_id.to_string()).await
} }
/// Return a copy of the disaggregated endpoint recorded for `worker_id`.
///
/// Yields `None` when the worker is unknown, when it has no runtime config,
/// or when its runtime config carries no disaggregated endpoint.
pub async fn get_disaggregated_endpoint(
    &self,
    worker_id: WorkerId,
) -> Option<DisaggregatedEndpoint> {
    let configs = self.workers_with_configs.read().await;
    let runtime_config = configs.get(&worker_id)?.as_ref()?;
    runtime_config.disaggregated_endpoint.clone()
}
pub async fn get_potential_loads( pub async fn get_potential_loads(
&self, &self,
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
......
...@@ -7,6 +7,15 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned}; ...@@ -7,6 +7,15 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::protocols::tensor; use crate::protocols::tensor;
/// Bootstrap address published in a runtime config for disaggregated serving.
///
/// Both fields are optional: a field missing from the input deserializes to
/// `None`, and a `None` field is left out of the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct DisaggregatedEndpoint {
/// Host of the bootstrap endpoint, if known.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bootstrap_host: Option<String>,
/// Port of the bootstrap endpoint, if known.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bootstrap_port: Option<u16>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelRuntimeConfig { pub struct ModelRuntimeConfig {
pub total_kv_blocks: Option<u64>, pub total_kv_blocks: Option<u64>,
...@@ -40,6 +49,10 @@ pub struct ModelRuntimeConfig { ...@@ -40,6 +49,10 @@ pub struct ModelRuntimeConfig {
// doesn't provide JSON parsing. // doesn't provide JSON parsing.
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub tensor_model_config: Option<tensor::TensorModelConfig>, pub tensor_model_config: Option<tensor::TensorModelConfig>,
/// Bootstrap endpoint for disaggregated serving (prefill workers publish this)
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_endpoint: Option<DisaggregatedEndpoint>,
} }
const fn default_data_parallel_size() -> u32 { const fn default_data_parallel_size() -> u32 {
...@@ -58,6 +71,7 @@ impl Default for ModelRuntimeConfig { ...@@ -58,6 +71,7 @@ impl Default for ModelRuntimeConfig {
enable_local_indexer: false, enable_local_indexer: false,
runtime_data: HashMap::new(), runtime_data: HashMap::new(),
tensor_model_config: None, tensor_model_config: None,
disaggregated_endpoint: None,
} }
} }
} }
......
...@@ -10,6 +10,18 @@ use crate::kv_router::RouterConfigOverride; ...@@ -10,6 +10,18 @@ use crate::kv_router::RouterConfigOverride;
use crate::preprocessor::media::RdmaMediaDataDescriptor; use crate::preprocessor::media::RdmaMediaDataDescriptor;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
/// Bootstrap connection details attached to a request for disaggregated serving.
///
/// None of the fields carry a serde default, so all three must be present
/// when deserializing.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct BootstrapInfo {
/// The host address for bootstrap connection
pub bootstrap_host: String,
/// The port for bootstrap connection
pub bootstrap_port: u16,
/// Unique room ID for this request's KV transfer session
pub bootstrap_room: u64,
}
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PrefillResult { pub struct PrefillResult {
/// Disaggregated execution parameters /// Disaggregated execution parameters
...@@ -87,6 +99,11 @@ pub struct PreprocessedRequest { ...@@ -87,6 +99,11 @@ pub struct PreprocessedRequest {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_result: Option<PrefillResult>, pub prefill_result: Option<PrefillResult>,
/// Bootstrap info for disaggregated serving
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bootstrap_info: Option<BootstrapInfo>,
/// Data parallel rank for the request (used with data parallelism) /// Data parallel rank for the request (used with data parallelism)
#[builder(default)] #[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment