Unverified Commit 5802db87 authored by William Arnold's avatar William Arnold Committed by GitHub
Browse files

feat(sglang): add logprob passthrough in decode handler (#6837)


Signed-off-by: default avatarWilliam Arnold <7565007+Aphoh@users.noreply.github.com>
Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent c53ef736
...@@ -85,6 +85,7 @@ BaseGenerativeHandler (handler_base.py) ...@@ -85,6 +85,7 @@ BaseGenerativeHandler (handler_base.py)
DecodeWorkerHandler (llm/decode_handler.py) DecodeWorkerHandler (llm/decode_handler.py)
Aggregated + disaggregated decode. Token/text streaming. Aggregated + disaggregated decode. Token/text streaming.
Logprob passthrough via _build_logprob_kwargs() + _extract_logprobs().
DiffusionWorkerHandler (llm/diffusion_handler.py) DiffusionWorkerHandler (llm/diffusion_handler.py)
LLM diffusion (DLLM). Simplified decode without disagg. LLM diffusion (DLLM). Simplified decode without disagg.
...@@ -185,12 +186,13 @@ capture SGLang's internal signal registrations and defer them. On SIGTERM/SIGINT ...@@ -185,12 +186,13 @@ capture SGLang's internal signal registrations and defer them. On SIGTERM/SIGINT
``` ```
Frontend (Rust, lib/llm/) Frontend (Rust, lib/llm/)
-> Preprocessor (tokenizes, builds PreprocessedRequest with token_ids + sampling + stop) -> Preprocessor (tokenizes, builds PreprocessedRequest with token_ids + sampling + stop + output_options)
-> Dynamo RPC to endpoint (dyn://{namespace}.{component}.{endpoint}) -> Dynamo RPC to endpoint (dyn://{namespace}.{component}.{endpoint})
-> Python handler.generate(request_dict, context) -> Python handler.generate(request_dict, context)
handler._build_sampling_params(request) -> SGLang-native params handler._build_sampling_params(request) -> SGLang-native params
engine.async_generate(**params) -> async iterator of dicts handler._build_logprob_kwargs(request) -> {return_logprob, top_logprobs_num, logprob_start_len}
handler yields {token_ids, text, finish_reason, ...} back to frontend engine.async_generate(**params, **logprob_kwargs) -> async iterator of dicts
handler yields {token_ids, text, finish_reason, log_probs, top_logprobs, ...} back to frontend
-> Frontend postprocesses into OpenAI-compatible response -> Frontend postprocesses into OpenAI-compatible response
``` ```
...@@ -204,6 +206,36 @@ Image/video diffusion handlers receive the full OpenAI-format request dict direc ...@@ -204,6 +206,36 @@ Image/video diffusion handlers receive the full OpenAI-format request dict direc
(not preprocessed), since the frontend passes through diffusion requests without (not preprocessed), since the frontend passes through diffusion requests without
tokenization. tokenization.
## Logprobs
`DecodeWorkerHandler` supports logprob passthrough, matching the vLLM and TRT-LLM backends.
Controlled by `output_options` in the preprocessed request (from Rust `OutputOptions` struct
in `lib/llm/src/protocols/common.rs`).
**Mapping from OutputOptions to SGLang kwargs** (`_build_logprob_kwargs`):
| OutputOptions field | SGLang kwarg | Notes |
|---------------------|-------------|-------|
| `logprobs: N` | `return_logprob=True, top_logprobs_num=N` | N top logprobs per output token |
| `prompt_logprobs: M` | `return_logprob=True, logprob_start_len=0` | Compute from prompt start |
| Both set | `top_logprobs_num=max(N, M)` | SGLang has a single top_logprobs_num for both |
`logprob_start_len` is SGLang-internal, not exposed in OutputOptions. It controls the
absolute sequence position where logprob computation starts: `-1` (default) = output tokens
only (`len(prompt) - 1`), `0` = from prompt start. We set it to 0 when `prompt_logprobs`
is requested.
**Streaming behavior** (`_extract_logprobs`):
Dynamo forces `stream_output=True` (args.py:374), making `output_ids` disjoint per chunk.
However, SGLang's `meta_info["output_token_logprobs"]` and `meta_info["output_top_logprobs"]`
are always **cumulative** — they grow with each chunk. The handler tracks
`num_output_logprobs_so_far` to slice out only new entries per chunk.
SGLang logprob format: `(logprob, token_id, text_or_None)` tuples.
Dynamo output format: `log_probs` = list of floats, `top_logprobs` = list of lists of
`{rank, token_id, token, logprob}` dicts (same as vLLM/TRT-LLM).
## Health Checks ## Health Checks
Each worker type has a custom health check payload (`health_check.py`): Each worker type has a custom health check payload (`health_check.py`):
...@@ -246,6 +278,9 @@ text-to-video-diffusion.sh # 1-2 GPUs - Text-to-video (Wan2.1) ...@@ -246,6 +278,9 @@ text-to-video-diffusion.sh # 1-2 GPUs - Text-to-video (Wan2.1)
- **output_modalities default**: Global default is `["text"]`. Image/video diffusion - **output_modalities default**: Global default is `["text"]`. Image/video diffusion
workers must override to `["image"]`/`["video"]` or the Rust registration path tries workers must override to `["image"]`/`["video"]` or the Rust registration path tries
to load `config.json` (which doesn't exist for diffusers models). to load `config.json` (which doesn't exist for diffusers models).
- **Cumulative logprobs in streaming**: SGLang's `output_token_logprobs`/`output_top_logprobs`
in `meta_info` are cumulative even though `output_ids` are disjoint (stream_output=True).
Always slice with an offset, don't assume per-chunk logprobs.
- **Zombie GPU processes**: `sgl_diffusion::scheduler` spawns a child process that - **Zombie GPU processes**: `sgl_diffusion::scheduler` spawns a child process that
survives parent kill. Always check `nvidia-smi` after teardown. survives parent kill. Always check `nvidia-smi` after teardown.
......
...@@ -111,6 +111,138 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -111,6 +111,138 @@ class DecodeWorkerHandler(BaseWorkerHandler):
return {k: v for k, v in param_mapping.items() if v is not None} return {k: v for k, v in param_mapping.items() if v is not None}
@staticmethod
def _build_logprob_kwargs(request: Dict[str, Any]) -> Dict[str, Any]:
"""Build logprob kwargs for SGLang async_generate from output_options.
Maps the Dynamo output_options format (shared with vLLM/TRT-LLM) to
SGLang's async_generate keyword arguments:
- return_logprob (bool): enables logprob computation
- top_logprobs_num (int): number of top-k logprobs per token
- logprob_start_len (int): absolute position in the sequence where
logprob computation begins. SGLang defaults this to -1, which
means len(prompt) - 1 (i.e. output tokens only). Setting it to 0
computes logprobs from the start of the prompt — this is how we
implement prompt_logprobs. We don't expose logprob_start_len
directly; it's an SGLang-internal detail derived from whether the
user requested prompt_logprobs.
Args:
request: Request dict containing optional output_options.
Returns:
Dict of logprob-related kwargs for engine.async_generate().
"""
kwargs: Dict[str, Any] = {}
output_options = request.get("output_options", {})
if not output_options:
return kwargs
logprobs_value = output_options.get("logprobs")
if logprobs_value is not None:
try:
parsed = int(logprobs_value)
if parsed < 0:
logging.warning(
f"Invalid logprobs value: {logprobs_value} "
"(must be non-negative), ignoring"
)
else:
kwargs["return_logprob"] = True
kwargs["top_logprobs_num"] = parsed
except (ValueError, TypeError):
logging.warning(
f"Invalid logprobs value: {logprobs_value} "
"(must be integer), ignoring"
)
prompt_logprobs_value = output_options.get("prompt_logprobs")
if prompt_logprobs_value is not None:
try:
parsed = int(prompt_logprobs_value)
if parsed < 0:
logging.warning(
f"Invalid prompt_logprobs value: {prompt_logprobs_value} "
"(must be non-negative), ignoring"
)
else:
kwargs["return_logprob"] = True
# SGLang has a single top_logprobs_num for both prompt
# and output tokens, so take the max of the two.
kwargs["top_logprobs_num"] = max(
kwargs.get("top_logprobs_num", 0), parsed
)
# logprob_start_len=0 computes from prompt start;
# omitting it (or -1) computes output tokens only.
kwargs["logprob_start_len"] = 0
except (ValueError, TypeError):
logging.warning(
f"Invalid prompt_logprobs value: {prompt_logprobs_value} "
"(must be integer), ignoring"
)
return kwargs
@staticmethod
def _extract_logprobs(
meta_info: Dict[str, Any], num_output_logprobs_so_far: int
) -> tuple:
"""Extract logprobs from SGLang meta_info for new tokens.
While Dynamo forces stream_output=True (args.py) so that output_ids
are disjoint per chunk, SGLang's output_token_logprobs and
output_top_logprobs in meta_info are always cumulative. We track an
offset to slice out only the new entries each chunk.
Args:
meta_info: SGLang response meta_info dict.
num_output_logprobs_so_far: Number of logprob entries already
processed in previous chunks.
Returns:
Tuple of (log_probs, top_logprobs, new_total):
- log_probs: List of floats (selected token logprob per position)
- top_logprobs: List of lists of dicts with rank/token_id/token/logprob
- new_total: Updated count of logprob entries processed so far
"""
output_token_logprobs = meta_info.get("output_token_logprobs")
if not output_token_logprobs:
return None, None, num_output_logprobs_so_far
new_logprobs = output_token_logprobs[num_output_logprobs_so_far:]
if not new_logprobs:
return None, None, num_output_logprobs_so_far
# Extract selected-token logprobs: each entry is (logprob, token_id, text_or_None)
log_probs = [float(entry[0]) for entry in new_logprobs]
# Extract top logprobs if available
top_logprobs: list[list[dict[str, Any]]] | None = None
output_top = meta_info.get("output_top_logprobs")
if output_top:
new_top = output_top[num_output_logprobs_so_far:]
if new_top:
top_logprobs = []
for position_entries in new_top:
if position_entries is None:
top_logprobs.append([])
continue
position_list = []
for rank_idx, entry in enumerate(position_entries):
position_list.append(
{
"rank": rank_idx + 1,
"token_id": entry[1],
"token": entry[2],
"logprob": float(entry[0]),
}
)
top_logprobs.append(position_list)
new_total = len(output_token_logprobs)
return log_probs, top_logprobs, new_total
async def generate( async def generate(
self, request: Dict[str, Any], context: Context self, request: Dict[str, Any], context: Context
) -> AsyncGenerator[Dict[str, Any], None]: ) -> AsyncGenerator[Dict[str, Any], None]:
...@@ -134,6 +266,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -134,6 +266,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self.config.server_args, "enable_return_routed_experts", False self.config.server_args, "enable_return_routed_experts", False
) )
priority = (request.get("routing") or {}).get("priority") priority = (request.get("routing") or {}).get("priority")
logprob_kwargs = self._build_logprob_kwargs(request)
if self.serving_mode == DisaggregationMode.DECODE: if self.serving_mode == DisaggregationMode.DECODE:
# Check if bootstrap_info is pre-computed in the request (from frontend) # Check if bootstrap_info is pre-computed in the request (from frontend)
...@@ -168,6 +301,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -168,6 +301,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
external_trace_header=trace_header, external_trace_header=trace_header,
rid=trace_id, rid=trace_id,
data_parallel_rank=dp_rank, data_parallel_rank=dp_rank,
**logprob_kwargs,
**self._priority_kwargs(priority), **self._priority_kwargs(priority),
) )
...@@ -200,6 +334,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -200,6 +334,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
external_trace_header=trace_header, external_trace_header=trace_header,
rid=trace_id, rid=trace_id,
data_parallel_rank=dp_rank, data_parallel_rank=dp_rank,
**logprob_kwargs,
**self._priority_kwargs(priority), **self._priority_kwargs(priority),
) )
if self.skip_tokenizer_init: if self.skip_tokenizer_init:
...@@ -228,6 +363,9 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -228,6 +363,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
""" """
# Use Future pattern for request ID - will be set when first response arrives # Use Future pattern for request ID - will be set when first response arrives
request_id_future: asyncio.Future[str] = asyncio.Future() request_id_future: asyncio.Future[str] = asyncio.Future()
# Logprob offset: output_ids are disjoint (stream_output=True) but
# meta_info logprobs are cumulative — track how many we've emitted.
num_output_logprobs_so_far = 0
async with self._cancellation_monitor(request_id_future, context): async with self._cancellation_monitor(request_id_future, context):
async for res in stream_source: async for res in stream_source:
# Extract SGLang request ID from the first response and set the future # Extract SGLang request ID from the first response and set the future
...@@ -260,6 +398,18 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -260,6 +398,18 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# Pass through disjoint token segments directly # Pass through disjoint token segments directly
out["token_ids"] = output_ids out["token_ids"] = output_ids
# Extract logprobs for new tokens if available
(
log_probs,
top_logprobs,
num_output_logprobs_so_far,
) = self._extract_logprobs(res["meta_info"], num_output_logprobs_so_far)
if log_probs is not None:
out["log_probs"] = log_probs
if top_logprobs is not None:
out["top_logprobs"] = top_logprobs
routed_experts = res["meta_info"].get("routed_experts") routed_experts = res["meta_info"].get("routed_experts")
if routed_experts is not None: if routed_experts is not None:
# Base64-encode tensor bytes to match sglang's output format. # Base64-encode tensor bytes to match sglang's output format.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment