Unverified commit 7053bb25, authored by ishandhanani and committed by GitHub
Browse files

fix(sglang): restore fallback for disaggregated decode without --router-mode kv (#5075)


Signed-off-by: Anant Sharma <anants@nvidia.com>
Co-authored-by: Anant Sharma <anants@nvidia.com>
parent 01f77f2c
...@@ -137,6 +137,17 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -137,6 +137,17 @@ async def init(runtime: DistributedRuntime, config: Config):
"Registered engine routes: /engine/start_profile, /engine/stop_profile" "Registered engine routes: /engine/start_profile, /engine/stop_profile"
) )
# Create prefill client for disaggregated decode mode (fallback when --router-mode kv is not used)
prefill_client = None
if config.serving_mode == DisaggregationMode.DECODE:
logging.info("Initializing prefill client for disaggregated decode worker")
prefill_client = (
await runtime.namespace(dynamo_args.namespace)
.component("prefill")
.endpoint("generate")
.client()
)
# publisher instantiates the metrics and kv event publishers # publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint engine, config, component, generate_endpoint
...@@ -149,7 +160,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -149,7 +160,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# Readiness gate: requests wait until model is registered # Readiness gate: requests wait until model is registered
ready_event = asyncio.Event() ready_event = asyncio.Event()
handler = DecodeWorkerHandler(component, engine, config, publisher) handler = DecodeWorkerHandler(component, engine, config, publisher, prefill_client)
print(f"Config: {config}") print(f"Config: {config}")
health_check_payload = SglangHealthCheckPayload( health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer engine, use_text_input=dynamo_args.use_sglang_tokenizer
......
...@@ -4,18 +4,16 @@ ...@@ -4,18 +4,16 @@
import asyncio import asyncio
import logging import logging
import time import time
from typing import Any, AsyncGenerator, Dict from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl import sglang as sgl
from dynamo._core import Component, Context from dynamo._core import Client, Component, Context
from dynamo.sglang.args import Config, DisaggregationMode from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.protocol import DisaggPreprocessedRequest
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
# Timeout for decode engine to receive first response when waiting for KV cache transfer
DECODE_KV_TRANSFER_TIMEOUT_SECONDS = 60.0
class DecodeWorkerHandler(BaseWorkerHandler): class DecodeWorkerHandler(BaseWorkerHandler):
"""Handler for decode workers in both aggregated and disaggregated serving modes.""" """Handler for decode workers in both aggregated and disaggregated serving modes."""
...@@ -26,6 +24,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -26,6 +24,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
publisher: DynamoSglangPublisher, publisher: DynamoSglangPublisher,
prefill_client: Optional[Client] = None,
) -> None: ) -> None:
"""Initialize decode worker handler. """Initialize decode worker handler.
...@@ -34,12 +33,14 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -34,12 +33,14 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: The SGLang engine instance. engine: The SGLang engine instance.
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker. publisher: Metrics publisher for the worker.
prefill_client: Optional client for prefill worker in disaggregated mode.
""" """
super().__init__( super().__init__(
component, component,
engine, engine,
config, config,
publisher, publisher,
prefill_client,
) )
if self.serving_mode == DisaggregationMode.DECODE: if self.serving_mode == DisaggregationMode.DECODE:
logging.info( logging.info(
...@@ -107,16 +108,52 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -107,16 +108,52 @@ class DecodeWorkerHandler(BaseWorkerHandler):
input_param = self._get_input_param(request) input_param = self._get_input_param(request)
if self.serving_mode == DisaggregationMode.DECODE: if self.serving_mode == DisaggregationMode.DECODE:
# Check if bootstrap_info is in the request # Check if bootstrap_info is pre-computed in the request (from frontend with --router-mode kv)
bootstrap_info = request.get("bootstrap_info") bootstrap_info = request.get("bootstrap_info")
if not bootstrap_info:
# Fallback: fetch bootstrap_info from prefill worker via round-robin routing
if self.prefill_client is None:
raise RuntimeError(
"bootstrap_info is required for disaggregated decode but was not provided, "
"and no prefill_client is available for fallback."
)
logging.debug(
"No bootstrap_info in request, fetching from prefill worker"
)
prefill_stream = await self.prefill_client.generate(
DisaggPreprocessedRequest(
request=request,
sampling_params=sampling_params,
).model_dump(),
context=context,
)
prefill_response = None
async for info in prefill_stream:
prefill_response = info.data()
break
if not prefill_response:
raise RuntimeError("No response received from prefill worker")
# Extract bootstrap_info from disaggregated_params (PrefillWorkerHandler format)
bootstrap_info = prefill_response.get("disaggregated_params")
if not bootstrap_info: if not bootstrap_info:
raise RuntimeError( raise RuntimeError(
"bootstrap_info is required for disaggregated decode but was not provided." "No bootstrap info (disaggregated_params) received from prefill worker"
) )
logging.debug( logging.debug(
f"Using bootstrap_info: " f"Received bootstrap_info from prefill worker: "
f"host={bootstrap_info['bootstrap_host']}, "
f"port={bootstrap_info['bootstrap_port']}, "
f"room={bootstrap_info['bootstrap_room']}"
)
else:
logging.debug(
f"Using pre-computed bootstrap_info: "
f"host={bootstrap_info['bootstrap_host']}, " f"host={bootstrap_info['bootstrap_host']}, "
f"port={bootstrap_info['bootstrap_port']}, " f"port={bootstrap_info['bootstrap_port']}, "
f"room={bootstrap_info['bootstrap_room']}" f"room={bootstrap_info['bootstrap_room']}"
...@@ -137,28 +174,11 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -137,28 +174,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
rid=trace_id, rid=trace_id,
) )
# Wait for first token with timeout
decode_iter = decode.__aiter__()
try:
first_res = await asyncio.wait_for(
decode_iter.__anext__(), timeout=DECODE_KV_TRANSFER_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
raise RuntimeError(
f"Decode timed out after {DECODE_KV_TRANSFER_TIMEOUT_SECONDS}s waiting for first token. "
)
# Create stream starting with first result
async def decode_stream() -> AsyncGenerator[Dict[str, Any], None]:
yield first_res
async for res in decode_iter:
yield res
if self.skip_tokenizer_init: if self.skip_tokenizer_init:
async for out in self._process_token_stream(decode_stream(), context): async for out in self._process_token_stream(decode, context):
yield out yield out
else: else:
async for out in self._process_text_stream(decode_stream(), context): async for out in self._process_text_stream(decode, context):
yield out yield out
else: else:
if self.enable_trace: if self.enable_trace:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment