Unverified Commit 7053bb25 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

fix(sglang): restore fallback for disaggregated decode without --router-mode kv (#5075)


Signed-off-by: default avatarAnant Sharma <anants@nvidia.com>
Co-authored-by: default avatarAnant Sharma <anants@nvidia.com>
parent 01f77f2c
......@@ -137,6 +137,17 @@ async def init(runtime: DistributedRuntime, config: Config):
"Registered engine routes: /engine/start_profile, /engine/stop_profile"
)
# Create prefill client for disaggregated decode mode (fallback when --router-mode kv is not used)
prefill_client = None
if config.serving_mode == DisaggregationMode.DECODE:
logging.info("Initializing prefill client for disaggregated decode worker")
prefill_client = (
await runtime.namespace(dynamo_args.namespace)
.component("prefill")
.endpoint("generate")
.client()
)
# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
......@@ -149,7 +160,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
handler = DecodeWorkerHandler(component, engine, config, publisher)
handler = DecodeWorkerHandler(component, engine, config, publisher, prefill_client)
print(f"Config: {config}")
health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer
......
......@@ -4,18 +4,16 @@
import asyncio
import logging
import time
from typing import Any, AsyncGenerator, Dict
from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl
from dynamo._core import Component, Context
from dynamo._core import Client, Component, Context
from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.protocol import DisaggPreprocessedRequest
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
# Timeout for decode engine to receive first response when waiting for KV cache transfer
DECODE_KV_TRANSFER_TIMEOUT_SECONDS = 60.0
class DecodeWorkerHandler(BaseWorkerHandler):
"""Handler for decode workers in both aggregated and disaggregated serving modes."""
......@@ -26,6 +24,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: sgl.Engine,
config: Config,
publisher: DynamoSglangPublisher,
prefill_client: Optional[Client] = None,
) -> None:
"""Initialize decode worker handler.
......@@ -34,12 +33,14 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker.
prefill_client: Optional client for prefill worker in disaggregated mode.
"""
super().__init__(
component,
engine,
config,
publisher,
prefill_client,
)
if self.serving_mode == DisaggregationMode.DECODE:
logging.info(
......@@ -107,16 +108,52 @@ class DecodeWorkerHandler(BaseWorkerHandler):
input_param = self._get_input_param(request)
if self.serving_mode == DisaggregationMode.DECODE:
# Check if bootstrap_info is in the request
# Check if bootstrap_info is pre-computed in the request (from frontend with --router-mode kv)
bootstrap_info = request.get("bootstrap_info")
if not bootstrap_info:
# Fallback: fetch bootstrap_info from prefill worker via round-robin routing
if self.prefill_client is None:
raise RuntimeError(
"bootstrap_info is required for disaggregated decode but was not provided, "
"and no prefill_client is available for fallback."
)
logging.debug(
"No bootstrap_info in request, fetching from prefill worker"
)
prefill_stream = await self.prefill_client.generate(
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump(),
context=context,
)
prefill_response = None
async for info in prefill_stream:
prefill_response = info.data()
break
if not prefill_response:
raise RuntimeError("No response received from prefill worker")
# Extract bootstrap_info from disaggregated_params (PrefillWorkerHandler format)
bootstrap_info = prefill_response.get("disaggregated_params")
if not bootstrap_info:
raise RuntimeError(
"bootstrap_info is required for disaggregated decode but was not provided."
"No bootstrap info (disaggregated_params) received from prefill worker"
)
logging.debug(
f"Using bootstrap_info: "
f"Received bootstrap_info from prefill worker: "
f"host={bootstrap_info['bootstrap_host']}, "
f"port={bootstrap_info['bootstrap_port']}, "
f"room={bootstrap_info['bootstrap_room']}"
)
else:
logging.debug(
f"Using pre-computed bootstrap_info: "
f"host={bootstrap_info['bootstrap_host']}, "
f"port={bootstrap_info['bootstrap_port']}, "
f"room={bootstrap_info['bootstrap_room']}"
......@@ -137,28 +174,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
rid=trace_id,
)
# Wait for first token with timeout
decode_iter = decode.__aiter__()
try:
first_res = await asyncio.wait_for(
decode_iter.__anext__(), timeout=DECODE_KV_TRANSFER_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
raise RuntimeError(
f"Decode timed out after {DECODE_KV_TRANSFER_TIMEOUT_SECONDS}s waiting for first token. "
)
# Create stream starting with first result
async def decode_stream() -> AsyncGenerator[Dict[str, Any], None]:
yield first_res
async for res in decode_iter:
yield res
if self.skip_tokenizer_init:
async for out in self._process_token_stream(decode_stream(), context):
async for out in self._process_token_stream(decode, context):
yield out
else:
async for out in self._process_text_stream(decode_stream(), context):
async for out in self._process_text_stream(decode, context):
yield out
else:
if self.enable_trace:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment