Unverified commit 1c39e483, authored by Sean SH Choi and committed by GitHub
Browse files

feat: add frontend based prefill request routing for sglang (#4635)


Signed-off-by: Sean Choi <sechoi@nvidia.com>
Signed-off-by: PeaBrane <yanrpei@gmail.com>
Co-authored-by: Yan Ru Pei <yanrpei@gmail.com>
Co-authored-by: ishandhanani <82981111+ishandhanani@users.noreply.github.com>
parent 7a5f70f3
......@@ -137,22 +137,6 @@ async def init(runtime: DistributedRuntime, config: Config):
"Registered engine routes: /engine/start_profile, /engine/stop_profile"
)
prefill_client = None
prefill_router_client = None
if config.serving_mode == DisaggregationMode.DECODE:
prefill_router_client = (
await runtime.namespace(dynamo_args.namespace)
.component("router")
.endpoint("best_worker_id")
.client()
)
prefill_client = (
await runtime.namespace(dynamo_args.namespace)
.component("prefill")
.endpoint("generate")
.client()
)
# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
......@@ -165,9 +149,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
handler = DecodeWorkerHandler(
component, engine, config, publisher, prefill_client, prefill_router_client
)
handler = DecodeWorkerHandler(component, engine, config, publisher)
print(f"Config: {config}")
health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer
......@@ -272,17 +254,29 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
tasks = [
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
try:
# Start endpoint immediately and register model concurrently
# Registration publishes runtime_config with bootstrap endpoint for optimization
await asyncio.gather(
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=metrics_labels,
health_check_payload=health_check_payload,
),
register_llm_with_readiness_gate(
engine,
generate_endpoint,
server_args,
dynamo_args,
input_type=ModelInput.Tokens,
output_type=ModelType.Prefill,
readiness_gate=ready_event,
),
)
]
try:
await asyncio.gather(*tasks)
except Exception as e:
logging.error(f"Failed to serve endpoints: {e}")
raise
......
......@@ -3,10 +3,12 @@
import asyncio
import logging
import socket
from typing import Optional
import sglang as sgl
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Endpoint
from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
......@@ -65,6 +67,39 @@ async def _register_llm_with_runtime_config(
return False
def _get_bootstrap_info_for_config(
    engine: sgl.Engine,
) -> tuple[Optional[str], Optional[int]]:
    """Read the disaggregation bootstrap host and port from an SGLang engine.

    The port comes from the engine's ``server_args.disaggregation_bootstrap_port``.
    The host is resolved from the host part of ``server_args.dist_init_addr``
    when that is set, and falls back to ``get_local_ip_auto()`` otherwise.

    Args:
        engine: The SGLang engine instance.

    Returns:
        ``(bootstrap_host, bootstrap_port)``, or ``(None, None)`` when no
        bootstrap port is configured or any lookup step raises.
    """
    unavailable = (None, None)
    try:
        args = engine.tokenizer_manager.server_args
        port = getattr(args, "disaggregation_bootstrap_port", None)
        if port is None:
            return unavailable

        init_addr = args.dist_init_addr
        if not init_addr:
            # No distributed init address configured: use this machine's IP.
            return get_local_ip_auto(), port

        # dist_init_addr is "host:port"; only the host part is resolved.
        node_name = init_addr.split(":")[0]
        return socket.gethostbyname(node_name), port
    except Exception as e:
        # Best effort: a failure here only disables the bootstrap endpoint.
        logging.warning(f"Failed to get bootstrap info: {e}")
        return unavailable
async def _get_runtime_config(
engine: sgl.Engine, server_args: ServerArgs, dynamo_args: DynamoArgs
) -> Optional[ModelRuntimeConfig]:
......@@ -84,6 +119,14 @@ async def _get_runtime_config(
runtime_config.tool_call_parser = dynamo_args.tool_call_parser
runtime_config.enable_local_indexer = dynamo_args.enable_local_indexer
# Set bootstrap endpoint for disaggregated serving (prefill workers)
bootstrap_host, bootstrap_port = _get_bootstrap_info_for_config(engine)
if bootstrap_host and bootstrap_port:
runtime_config.set_disaggregated_endpoint(bootstrap_host, bootstrap_port)
logging.info(
f"Publishing disaggregated endpoint to discovery: "
f"{bootstrap_host}:{bootstrap_port}"
)
# In SGLang, these are server_args, not scheduler_info (unlike vLLM)
# Note: If --max-running-requests is not specified, SGLang falls back to an
# undocumented internal default. The value here will be None unless the user set it explicitly.
......
......@@ -4,16 +4,18 @@
import asyncio
import logging
import time
from typing import Any, AsyncGenerator, Dict, Optional
from typing import Any, AsyncGenerator, Dict
import sglang as sgl
from dynamo._core import Client, Component, Context
from dynamo._core import Component, Context
from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.protocol import DisaggPreprocessedRequest
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
# Timeout for decode engine to receive first response when waiting for KV cache transfer
DECODE_KV_TRANSFER_TIMEOUT_SECONDS = 60.0
class DecodeWorkerHandler(BaseWorkerHandler):
"""Handler for decode workers in both aggregated and disaggregated serving modes."""
......@@ -24,8 +26,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: sgl.Engine,
config: Config,
publisher: DynamoSglangPublisher,
prefill_client: Optional[Client] = None,
prefill_router_client: Optional[Client] = None,
) -> None:
"""Initialize decode worker handler.
......@@ -34,29 +34,19 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker.
prefill_client: Optional client for prefill worker in disaggregated mode.
prefill_router_client: Optional client for prefill router in disaggregated mode.
Raises:
ValueError: If prefill_client is not provided in decode serving mode.
"""
super().__init__(
component,
engine,
config,
publisher,
prefill_client,
)
if self.serving_mode == DisaggregationMode.DECODE:
if self.prefill_client is None:
raise ValueError(
"prefill_client must be provided when serving_mode is decode"
logging.info(
"Decode worker handler initialized (disaggregated decode mode)"
)
self.prefill_client = prefill_client
logging.info("Decode worker handler initialized")
self.prefill_router_client = prefill_router_client
logging.info("Worker handler initialized")
else:
logging.info("Decode worker handler initialized (aggregated mode)")
def cleanup(self) -> None:
"""Shutdown the engine and cleanup resources."""
......@@ -117,43 +107,20 @@ class DecodeWorkerHandler(BaseWorkerHandler):
input_param = self._get_input_param(request)
if self.serving_mode == DisaggregationMode.DECODE:
# request the bootstrap info from the target prefill worker
if (
self.prefill_router_client is not None
and self.prefill_router_client.instance_ids()
):
token_ids = request["token_ids"]
stream = await self.prefill_router_client.generate(token_ids)
result = await anext(stream)
(
worker_id,
overlap,
) = result.data() # Returns tuple (worker_id, overlap_amount)
logging.info(f"Best prefill worker ID: {worker_id}, overlap: {overlap}")
prefill_stream = await self.prefill_client.direct(
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump(),
worker_id,
)
else:
prefill_stream = await self.prefill_client.generate(
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump(),
context=context,
)
bootstrap_info = None
async for info in prefill_stream:
bootstrap_info = info.data()
break
# Check if bootstrap_info is in the request
bootstrap_info = request.get("bootstrap_info")
if not bootstrap_info:
raise RuntimeError("No bootstrap info received from prefill worker")
raise RuntimeError(
"bootstrap_info is required for disaggregated decode but was not provided."
)
logging.debug(
f"Using bootstrap_info: "
f"host={bootstrap_info['bootstrap_host']}, "
f"port={bootstrap_info['bootstrap_port']}, "
f"room={bootstrap_info['bootstrap_room']}"
)
if self.enable_trace:
self._propagate_trace_context_to_sglang(
......@@ -170,11 +137,28 @@ class DecodeWorkerHandler(BaseWorkerHandler):
rid=trace_id,
)
# Wait for first token with timeout
decode_iter = decode.__aiter__()
try:
first_res = await asyncio.wait_for(
decode_iter.__anext__(), timeout=DECODE_KV_TRANSFER_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
raise RuntimeError(
f"Decode timed out after {DECODE_KV_TRANSFER_TIMEOUT_SECONDS}s waiting for first token. "
)
# Create stream starting with first result
async def decode_stream() -> AsyncGenerator[Dict[str, Any], None]:
yield first_res
async for res in decode_iter:
yield res
if self.skip_tokenizer_init:
async for out in self._process_token_stream(decode, context):
async for out in self._process_token_stream(decode_stream(), context):
yield out
else:
async for out in self._process_text_stream(decode, context):
async for out in self._process_text_stream(decode_stream(), context):
yield out
else:
if self.enable_trace:
......
......@@ -57,7 +57,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
"""Generate prefill output and provide bootstrap info for decode worker.
Args:
request: Request dict with 'request' and 'sampling_params' keys.
request: Request dict with 'request', 'sampling_params', and possibly 'bootstrap_room' keys.
context: Context object for cancellation handling.
Yields:
......@@ -65,7 +65,35 @@ class PrefillWorkerHandler(BaseWorkerHandler):
"""
logging.debug(f"New Request ID: {context.id()}")
trace_id = context.trace_id
if "request" in request:
# DisaggPreprocessedRequest format
inner_request = request["request"]
sampling_params = request.get("sampling_params", {})
else:
inner_request = request
sampling_opts = request.get("sampling_options", {})
stop_conditions = request.get("stop_conditions", {})
sampling_params = {
"temperature": sampling_opts.get("temperature"),
"top_p": sampling_opts.get("top_p"),
"top_k": sampling_opts.get("top_k"),
"max_new_tokens": stop_conditions.get("max_tokens"),
}
sampling_params = {
k: v for k, v in sampling_params.items() if v is not None
}
# Use provided bootstrap_room if available, otherwise generate one
bootstrap_room = None
extra_args = inner_request.get("extra_args", {})
if isinstance(extra_args, dict):
bootstrap_room = extra_args.get("bootstrap_room")
logging.debug(f"Using router-provided bootstrap_room: {bootstrap_room}")
if bootstrap_room is None:
bootstrap_room = self._generate_bootstrap_room()
logging.debug(f"Generated bootstrap_room locally: {bootstrap_room}")
bootstrap_info = {
"bootstrap_host": self.bootstrap_host,
......@@ -73,9 +101,16 @@ class PrefillWorkerHandler(BaseWorkerHandler):
"bootstrap_room": bootstrap_room,
}
yield bootstrap_info
# Yield in LLMEngineOutput format for PrefillRouter compatibility
# The disaggregated_params field contains the bootstrap info
yield {
"token_ids": [],
"text": None,
"finish_reason": None,
"disaggregated_params": bootstrap_info,
}
input_param = self._get_input_param(request["request"])
input_param = self._get_input_param(inner_request)
# Propagate trace context to SGLang
if self.enable_trace:
......@@ -83,7 +118,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
results = await self.engine.async_generate(
**input_param,
sampling_params=request["sampling_params"],
sampling_params=sampling_params,
stream=True,
bootstrap_host=self.bootstrap_host,
bootstrap_port=self.bootstrap_port,
......
......@@ -5,8 +5,8 @@
# Setup cleanup trap
cleanup() {
echo "Cleaning up background processes..."
kill $DYNAMO_PID $PREFILL_PID $PREFILL_ROUTER_PID 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID $PREFILL_ROUTER_PID 2>/dev/null || true
kill $DYNAMO_PID $PREFILL_PID1 $PREFILL_PID2 $DECODE_PID1 2>/dev/null || true
wait $DYNAMO_PID $PREFILL_PID1 $PREFILL_PID2 $DECODE_PID1 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM
......@@ -26,7 +26,7 @@ while [[ $# -gt 0 ]]; do
echo " -h, --help Show this help message"
echo ""
echo "Note: System metrics are enabled by default on ports:"
echo " 8081 (router), 8082-8083 (prefill workers), 8084-8085 (decode workers)"
echo " 8082-8083 (prefill workers), 8084-8085 (decode workers)"
exit 0
;;
*)
......@@ -46,12 +46,12 @@ if [ "$ENABLE_OTEL" = true ]; then
TRACE_ARGS+=(--enable-trace --otlp-traces-endpoint localhost:4317)
fi
# run ingress
# Start frontend with KV routing
# The frontend will automatically detect prefill workers and activate an internal prefill router
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
OTEL_SERVICE_NAME=dynamo-frontend \
python3 -m dynamo.frontend \
--router-mode kv \
--kv-overlap-score-weight 0 \
--router-reset-states &
DYNAMO_PID=$!
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use super::*;
use llm_rs::local_model::runtime_config::DisaggregatedEndpoint as RsDisaggregatedEndpoint;
use llm_rs::local_model::runtime_config::ModelRuntimeConfig as RsModelRuntimeConfig;
#[pyclass]
......@@ -125,4 +126,32 @@ impl ModelRuntimeConfig {
fn get_engine_specific(&self, key: &str) -> PyResult<Option<String>> {
self.inner.get_engine_specific(key).map_err(to_pyerr)
}
/// Store the bootstrap endpoint on the wrapped runtime config.
///
/// Both arguments are optional. The endpoint is always set to `Some(..)`,
/// replacing any previously stored endpoint, even when both are `None`.
#[pyo3(signature = (bootstrap_host=None, bootstrap_port=None))]
fn set_disaggregated_endpoint(
    &mut self,
    bootstrap_host: Option<String>,
    bootstrap_port: Option<u16>,
) {
    let endpoint = RsDisaggregatedEndpoint {
        bootstrap_host,
        bootstrap_port,
    };
    self.inner.disaggregated_endpoint = Some(endpoint);
}
/// Bootstrap host of the stored disaggregated endpoint.
///
/// Returns `None` when no endpoint is stored or the endpoint has no host.
#[getter]
fn bootstrap_host(&self) -> Option<String> {
    let endpoint = self.inner.disaggregated_endpoint.as_ref()?;
    endpoint.bootstrap_host.clone()
}
/// Bootstrap port of the stored disaggregated endpoint.
///
/// Returns `None` when no endpoint is stored or the endpoint has no port.
#[getter]
fn bootstrap_port(&self) -> Option<u16> {
    let endpoint = self.inner.disaggregated_endpoint.as_ref()?;
    endpoint.bootstrap_port
}
}
......@@ -562,6 +562,15 @@ impl KvRouter {
self.block_size
}
/// Get the disaggregated endpoint for a worker, if available.
///
/// Used to look up bootstrap host/port for prefill workers. This is a thin
/// wrapper that forwards `worker_id` to the scheduler's
/// `get_disaggregated_endpoint` and returns its result unchanged, so it yields
/// `None` whenever the scheduler has no endpoint for that worker.
pub async fn get_disaggregated_endpoint(
&self,
worker_id: u64,
) -> Option<crate::local_model::runtime_config::DisaggregatedEndpoint> {
self.scheduler.get_disaggregated_endpoint(worker_id).await
}
/// Get potential prefill and decode loads for all workers
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len();
......
......@@ -5,6 +5,7 @@ use std::sync::{Arc, OnceLock};
use anyhow::Result;
use futures::StreamExt;
use rand::Rng;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
......@@ -21,7 +22,7 @@ use crate::{
discovery::ModelManager,
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::PrefillResult,
protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
};
/// Errors that can occur during prefill routing
......@@ -42,6 +43,7 @@ pub enum PrefillError {
}
/// The inner router used by PrefillRouter
#[derive(Clone)]
enum InnerPrefillRouter {
/// KV-aware routing using KvPushRouter
KvRouter(Arc<KvPushRouter>),
......@@ -49,6 +51,19 @@ enum InnerPrefillRouter {
SimpleRouter(Arc<PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>>),
}
impl InnerPrefillRouter {
    /// Run a prefill request through whichever router variant is active.
    ///
    /// Both variants expose a `generate` call with the same request type; this
    /// method only dispatches on the variant and returns the resulting stream
    /// (or the router's error) unchanged.
    async fn generate(
        &self,
        request: SingleIn<PreprocessedRequest>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
        match self {
            Self::KvRouter(kv) => kv.generate(request).await,
            Self::SimpleRouter(simple) => simple.generate(request).await,
        }
    }
}
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
......@@ -176,27 +191,74 @@ impl PrefillRouter {
Ok(())
}
/// Call the prefill router and extract structured prefill result
async fn call_prefill(
/// Generate a bootstrap room ID for disaggregated serving.
///
/// The ID is a uniformly random `u64`; uniqueness is probabilistic, not enforced.
// NOTE(review): values span the full u64 range — confirm the worker side
// accepts room IDs above i64::MAX.
fn generate_bootstrap_room() -> u64 {
    rand::rng().random::<u64>()
}
/// Query best worker upfront, build bootstrap_info, and spawn prefill in background
async fn build_bootstrap_info(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<PrefillResult, PrefillError> {
// Get the prefill router, error if not activated
let Some(prefill_router) = self.prefill_router.get() else {
return Err(PrefillError::NotActivated);
req: &PreprocessedRequest,
) -> Option<(u64, u32, BootstrapInfo)> {
let prefill_router = self.prefill_router.get()?;
// Only works with KvRouter
let kv_router = match prefill_router {
InnerPrefillRouter::KvRouter(r) => r,
InnerPrefillRouter::SimpleRouter(_) => return None,
};
// Call the appropriate router based on the type
let mut prefill_response = match prefill_router {
InnerPrefillRouter::KvRouter(router) => router
.generate(request)
// Query best worker without routing
let (worker_id, dp_rank) = match kv_router
.chooser
.find_best_match(None, &req.token_ids, None, false)
.await
.map_err(|e| PrefillError::PrefillError(e.to_string()))?,
InnerPrefillRouter::SimpleRouter(router) => router
{
Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank),
Err(_) => return None,
};
// Look up bootstrap endpoint from discovery
let endpoint = kv_router
.chooser
.get_disaggregated_endpoint(worker_id)
.await?;
let host = endpoint.bootstrap_host?;
let port = endpoint.bootstrap_port?;
let bootstrap_room = Self::generate_bootstrap_room();
tracing::info!(
worker_id = worker_id,
dp_rank = dp_rank,
bootstrap_host = %host,
bootstrap_port = port,
bootstrap_room = bootstrap_room,
"Built bootstrap_info upfront before prefill"
);
Some((
worker_id,
dp_rank,
BootstrapInfo {
bootstrap_host: host,
bootstrap_port: port,
bootstrap_room,
},
))
}
/// Execute prefill with the given router and extract structured result
async fn execute_prefill(
router: Option<InnerPrefillRouter>,
request: SingleIn<PreprocessedRequest>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> {
let router = router.ok_or(PrefillError::NotActivated)?;
let mut prefill_response = router
.generate(request)
.await
.map_err(|e| PrefillError::PrefillError(e.to_string()))?,
};
.map_err(|e| PrefillError::PrefillError(e.to_string()))?;
let Some(first_output) = prefill_response.next().await else {
return Err(PrefillError::PrefillError(
......@@ -239,10 +301,45 @@ impl PrefillRouter {
));
};
Ok(PrefillResult {
// Extract prefill worker ID from disaggregated_params
let prefill_worker_id = disaggregated_params
.get("worker_id")
.and_then(|worker_id_json| {
worker_id_json
.get("prefill_worker_id")
.and_then(|v| v.as_u64())
});
Ok((
PrefillResult {
disaggregated_params,
prompt_tokens_details,
})
},
prefill_worker_id,
))
}
/// Run prefill on a detached background task.
///
/// The current router (if any) is cloned up front so the spawned future owns
/// everything it needs. The task's outcome is only logged: completion at
/// `debug`, failure at `warn`. The join handle is dropped, so callers never
/// observe the prefill result.
fn spawn_prefill_task(&self, prefill_request: SingleIn<PreprocessedRequest>) {
    let router = self.prefill_router.get().cloned();

    tokio::spawn(async move {
        if let Err(e) = Self::execute_prefill(router, prefill_request).await {
            tracing::warn!("Prefill background task error: {e:?}");
        } else {
            tracing::debug!("Prefill background task completed");
        }
    });
}
/// Run prefill inline and return its structured result with the worker ID.
///
/// Clones the currently configured prefill router (if any) and hands it to
/// `execute_prefill`, awaiting the outcome. Errors from `execute_prefill` are
/// returned unchanged.
async fn call_prefill(
    &self,
    request: SingleIn<PreprocessedRequest>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> {
    let router = self.prefill_router.get().cloned();
    Self::execute_prefill(router, request).await
}
}
......@@ -278,15 +375,43 @@ impl
// Prepare prefill request with max_tokens = 1
let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1);
let prefill_context = Context::with_id(prefill_req, request_id.clone());
// Link the prefill context as a child so that kill signals propagate
// Try build_bootstrap_info optimization
let prefill_result = if let Some((worker_id, dp_rank, bootstrap_info)) =
self.build_bootstrap_info(&prefill_req).await
{
let bootstrap_room = bootstrap_info.bootstrap_room;
// Prepare request with bootstrap_room and force routing to specific worker
prefill_req.backend_instance_id = Some(worker_id);
prefill_req.dp_rank = Some(dp_rank);
let extra_args = prefill_req
.extra_args
.get_or_insert_with(|| serde_json::json!({}));
if let Some(obj) = extra_args.as_object_mut() {
obj.insert(
"bootstrap_room".to_string(),
serde_json::json!(bootstrap_room),
);
}
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
let prefill_request = prefill_context;
self.spawn_prefill_task(prefill_context);
Ok((None, Some(worker_id), Some(bootstrap_info)))
} else {
// Fallback to original: Wait for prefill to complete
tracing::debug!("Using original prefill path");
// Attempt prefill
let prefill_result = self.call_prefill(prefill_request).await;
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
};
// Abort if cancelled during prefill
if engine_ctx.is_stopped() || engine_ctx.is_killed() {
......@@ -299,15 +424,24 @@ impl
// Handle prefill result
match prefill_result {
Ok(prefill_result) => {
tracing::debug!("Prefill succeeded, using disaggregated params for decode");
Ok((maybe_prefill_result, _prefill_worker_id, bootstrap_info)) => {
tracing::debug!("Prefill completed, proceeding to decode");
let mut decode_req = req;
// Update request with prefill result
decode_req.prefill_result = Some(prefill_result.clone());
// Update request with prefill result if available (only in original path)
if let Some(prefill_result) = maybe_prefill_result {
decode_req.prefill_result = Some(prefill_result);
}
// Restore original max_tokens for decode
decode_req.stop_conditions.max_tokens = original_max_tokens;
// Inject bootstrap_info for decode worker
if let Some(info) = bootstrap_info {
decode_req.bootstrap_info = Some(info);
}
// Set router_config_override for decode: overlap_score_weight = 0
let existing_override = decode_req.router_config_override.take();
decode_req.router_config_override = Some(RouterConfigOverride {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::local_model::runtime_config::{DisaggregatedEndpoint, ModelRuntimeConfig};
use anyhow::Result;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
......@@ -90,6 +90,8 @@ impl SchedulingRequest {
pub struct KvScheduler {
request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMultiWorker>,
/// Worker runtime configs for looking up disaggregated endpoints
workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>>,
}
impl KvScheduler {
......@@ -287,7 +289,11 @@ impl KvScheduler {
tracing::trace!("background endpoint subscriber shutting down");
});
Ok(KvScheduler { request_tx, slots })
Ok(KvScheduler {
request_tx,
slots,
workers_with_configs,
})
}
pub async fn schedule(
......@@ -346,6 +352,17 @@ impl KvScheduler {
self.slots.free(&request_id.to_string()).await
}
/// Look up the disaggregated endpoint recorded for `worker_id`.
///
/// Takes a read lock on the worker-config map and returns a clone of the
/// endpoint. Yields `None` when the worker is unknown, has no runtime
/// config, or its config carries no disaggregated endpoint.
pub async fn get_disaggregated_endpoint(
    &self,
    worker_id: WorkerId,
) -> Option<DisaggregatedEndpoint> {
    let configs = self.workers_with_configs.read().await;
    let runtime_config = configs.get(&worker_id)?.as_ref()?;
    runtime_config.disaggregated_endpoint.clone()
}
pub async fn get_potential_loads(
&self,
token_seq: Option<Vec<SequenceHash>>,
......
......@@ -7,6 +7,15 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::protocols::tensor;
/// Bootstrap connection details for disaggregated serving.
///
/// Both fields are optional. A field that is `None` is left out of the
/// serialized form, and a field missing on deserialization becomes `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct DisaggregatedEndpoint {
/// Host of the bootstrap endpoint, if known.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bootstrap_host: Option<String>,
/// Port of the bootstrap endpoint, if known.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bootstrap_port: Option<u16>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelRuntimeConfig {
pub total_kv_blocks: Option<u64>,
......@@ -40,6 +49,10 @@ pub struct ModelRuntimeConfig {
// doesn't provide JSON parsing.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tensor_model_config: Option<tensor::TensorModelConfig>,
/// Bootstrap endpoint for disaggregated serving (prefill workers publish this)
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_endpoint: Option<DisaggregatedEndpoint>,
}
const fn default_data_parallel_size() -> u32 {
......@@ -58,6 +71,7 @@ impl Default for ModelRuntimeConfig {
enable_local_indexer: false,
runtime_data: HashMap::new(),
tensor_model_config: None,
disaggregated_endpoint: None,
}
}
}
......
......@@ -10,6 +10,18 @@ use crate::kv_router::RouterConfigOverride;
use crate::preprocessor::media::RdmaMediaDataDescriptor;
use crate::protocols::TokenIdType;
/// Bootstrap connection details for a single request in disaggregated serving.
///
/// Unlike an optional endpoint description, every field here is required:
/// a value of this type always carries a host, a port and a room ID.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct BootstrapInfo {
/// The host address for bootstrap connection
pub bootstrap_host: String,
/// The port for bootstrap connection
pub bootstrap_port: u16,
/// Unique room ID for this request's KV transfer session
pub bootstrap_room: u64,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PrefillResult {
/// Disaggregated execution parameters
......@@ -87,6 +99,11 @@ pub struct PreprocessedRequest {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_result: Option<PrefillResult>,
/// Bootstrap info for disaggregated serving
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bootstrap_info: Option<BootstrapInfo>,
/// Data parallel rank for the request (used with data parallelism)
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment