Unverified Commit 8ce830a8 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router][bugfix] Fix input_logprobs handling with None value and `logprob_start_len = -1` (#11113)

parent fb367acf
......@@ -486,6 +486,56 @@ class GrpcRequestManager:
if self.gracefully_exit:
break
def _convert_logprob_style(
self,
state: GrpcReqState,
batch_out: BatchTokenIDOut,
batch_index: int,
):
"""
Convert and accumulate logprobs from batch output to state.
Follows the same logic as tokenizer_manager.convert_logprob_style.
"""
# Early exit if no input logprobs at all
if batch_out.input_token_logprobs_val is None:
return
# Accumulate input token logprobs (only if list is non-empty)
if len(batch_out.input_token_logprobs_val) > 0:
state.input_token_logprobs_val.extend(
batch_out.input_token_logprobs_val[batch_index]
)
state.input_token_logprobs_idx.extend(
batch_out.input_token_logprobs_idx[batch_index]
)
# Always accumulate output token logprobs
state.output_token_logprobs_val.extend(
batch_out.output_token_logprobs_val[batch_index]
)
state.output_token_logprobs_idx.extend(
batch_out.output_token_logprobs_idx[batch_index]
)
# Handle top logprobs if requested
if state.obj.top_logprobs_num > 0:
# Accumulate input top logprobs (only if list is non-empty)
if len(batch_out.input_top_logprobs_val) > 0:
state.input_top_logprobs_val.extend(
batch_out.input_top_logprobs_val[batch_index]
)
state.input_top_logprobs_idx.extend(
batch_out.input_top_logprobs_idx[batch_index]
)
# Always accumulate output top logprobs
state.output_top_logprobs_val.extend(
batch_out.output_top_logprobs_val[batch_index]
)
state.output_top_logprobs_idx.extend(
batch_out.output_top_logprobs_idx[batch_index]
)
async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
"""Handle batch generation output from scheduler."""
# Process each request in the batch
......@@ -526,35 +576,16 @@ class GrpcRequestManager:
},
}
# Accumulate input logprobs (only once, usually in first chunk)
if batch_out.input_token_logprobs_val and i < len(
batch_out.input_token_logprobs_val
):
if not state.input_token_logprobs_val:
state.input_token_logprobs_val.extend(
batch_out.input_token_logprobs_val[i]
)
if batch_out.input_token_logprobs_idx and i < len(
batch_out.input_token_logprobs_idx
):
state.input_token_logprobs_idx.extend(
batch_out.input_token_logprobs_idx[i]
)
if batch_out.input_top_logprobs_val and i < len(
batch_out.input_top_logprobs_val
):
state.input_top_logprobs_val.extend(
batch_out.input_top_logprobs_val[i]
)
if batch_out.input_top_logprobs_idx and i < len(
batch_out.input_top_logprobs_idx
):
state.input_top_logprobs_idx.extend(
batch_out.input_top_logprobs_idx[i]
)
# Accumulate logprobs (following tokenizer_manager pattern)
if state.obj.return_logprob:
self._convert_logprob_style(state, batch_out, i)
# Send input logprobs based on mode
if state.input_token_logprobs_val:
# Send input logprobs based if available
if (
state.obj.return_logprob
and state.obj.logprob_start_len >= 0
and state.input_token_logprobs_val
):
if state.obj.stream and not state.input_logprobs_sent:
# Streaming: send input logprobs once in first chunk that has them
output_data["input_logprobs"] = {
......@@ -573,33 +604,12 @@ class GrpcRequestManager:
"top_logprobs_idx": state.input_top_logprobs_idx,
}
# Add output logprobs if available (RAW - no detokenization!)
if batch_out.output_token_logprobs_val and i < len(
batch_out.output_token_logprobs_val
# Send output logprobs if available
if (
state.obj.return_logprob
and batch_out.output_token_logprobs_val
and i < len(batch_out.output_token_logprobs_val)
):
# Accumulate in state first
state.output_token_logprobs_val.extend(
batch_out.output_token_logprobs_val[i]
)
if batch_out.output_token_logprobs_idx and i < len(
batch_out.output_token_logprobs_idx
):
state.output_token_logprobs_idx.extend(
batch_out.output_token_logprobs_idx[i]
)
if batch_out.output_top_logprobs_val and i < len(
batch_out.output_top_logprobs_val
):
state.output_top_logprobs_val.extend(
batch_out.output_top_logprobs_val[i]
)
if batch_out.output_top_logprobs_idx and i < len(
batch_out.output_top_logprobs_idx
):
state.output_top_logprobs_idx.extend(
batch_out.output_top_logprobs_idx[i]
)
if state.obj.stream:
# For streaming: send incremental logprobs (only new tokens in this chunk)
# NOTE: this is different than TokenizerManager, which always accumulates
......
......@@ -415,7 +415,11 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
mm_inputs=None, # TODO: implement mm support
sampling_params=sampling_params,
return_logprob=grpc_req.return_logprob,
logprob_start_len=grpc_req.logprob_start_len or -1,
logprob_start_len=(
grpc_req.logprob_start_len
if grpc_req.logprob_start_len is not None
else -1
),
top_logprobs_num=grpc_req.top_logprobs_num or 0,
stream=grpc_req.stream or False,
lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
......@@ -486,10 +490,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
ignore_eos=grpc_params.ignore_eos,
)
def _convert_logprobs_to_proto(
def _convert_output_logprobs_to_proto(
self, logprobs_data: Dict
) -> Optional[sglang_scheduler_pb2.LogProbs]:
"""Convert logprobs dict to proto LogProbs format (transport RAW data only)."""
) -> Optional[sglang_scheduler_pb2.OutputLogProbs]:
"""Convert output logprobs dict to proto (no None values, plain floats)."""
if not logprobs_data:
return None
......@@ -509,8 +513,47 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
)
)
return sglang_scheduler_pb2.LogProbs(
token_logprobs=token_logprobs_val,
return sglang_scheduler_pb2.OutputLogProbs(
token_logprobs=token_logprobs_val, # Plain float array
token_ids=token_logprobs_idx,
top_logprobs=top_logprobs_proto,
)
def _convert_input_logprobs_to_proto(
self, logprobs_data: Dict
) -> Optional[sglang_scheduler_pb2.InputLogProbs]:
"""Convert input logprobs dict to proto (first token is None, wrapped in InputTokenLogProb)."""
if not logprobs_data:
return None
token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
# Wrap values in InputTokenLogProb (None for first token, value for others)
token_logprobs_wrapped = [
(
sglang_scheduler_pb2.InputTokenLogProb()
if x is None
else sglang_scheduler_pb2.InputTokenLogProb(value=x)
)
for x in token_logprobs_val
]
# Build TopLogProbs entries
top_logprobs_proto = []
if top_logprobs_val and top_logprobs_idx:
for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
top_logprobs_proto.append(
sglang_scheduler_pb2.TopLogProbs(
values=val_list,
token_ids=idx_list,
)
)
return sglang_scheduler_pb2.InputLogProbs(
token_logprobs=token_logprobs_wrapped,
token_ids=token_logprobs_idx,
top_logprobs=top_logprobs_proto,
)
......@@ -522,12 +565,12 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
meta_info = output.get("meta_info", {})
# Convert output logprobs if present
output_logprobs_proto = self._convert_logprobs_to_proto(
output_logprobs_proto = self._convert_output_logprobs_to_proto(
output.get("output_logprobs")
)
# Convert input logprobs if present (only in first chunk)
input_logprobs_proto = self._convert_logprobs_to_proto(
input_logprobs_proto = self._convert_input_logprobs_to_proto(
output.get("input_logprobs")
)
......@@ -576,12 +619,12 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
matched_stop_kwargs["matched_stop_str"] = matched
# Convert output logprobs if present
output_logprobs_proto = self._convert_logprobs_to_proto(
output_logprobs_proto = self._convert_output_logprobs_to_proto(
output.get("output_logprobs")
)
# Convert input logprobs if present
input_logprobs_proto = self._convert_logprobs_to_proto(
input_logprobs_proto = self._convert_input_logprobs_to_proto(
output.get("input_logprobs")
)
......
......@@ -175,13 +175,13 @@ message GenerateStreamChunk {
int32 cached_tokens = 4;
// Output logprobs (if requested) - incremental for streaming
LogProbs output_logprobs = 5;
OutputLogProbs output_logprobs = 5;
// Hidden states (if requested)
repeated float hidden_states = 6;
// Input logprobs (if requested) - only in first chunk
LogProbs input_logprobs = 7;
InputLogProbs input_logprobs = 7;
}
message GenerateComplete {
......@@ -197,7 +197,7 @@ message GenerateComplete {
int32 cached_tokens = 5;
// Output logprobs if requested (cumulative)
LogProbs output_logprobs = 6;
OutputLogProbs output_logprobs = 6;
// All hidden states if requested
repeated HiddenStates all_hidden_states = 7;
......@@ -209,7 +209,7 @@ message GenerateComplete {
}
// Input logprobs if requested (for prompt tokens)
LogProbs input_logprobs = 10;
InputLogProbs input_logprobs = 10;
}
message GenerateError {
......@@ -218,7 +218,8 @@ message GenerateError {
string details = 3;
}
message LogProbs {
// Output logprobs - all values are present (no None)
message OutputLogProbs {
repeated float token_logprobs = 1;
repeated int32 token_ids = 2;
......@@ -226,6 +227,20 @@ message LogProbs {
repeated TopLogProbs top_logprobs = 3;
}
// Input logprobs - first token has no logprob (None)
message InputLogProbs {
repeated InputTokenLogProb token_logprobs = 1;
repeated int32 token_ids = 2;
// Top logprobs at each position
repeated TopLogProbs top_logprobs = 3;
}
// Wrapper to represent optional logprob (first input token has no logprob)
message InputTokenLogProb {
optional float value = 1;
}
message TopLogProbs {
repeated float values = 1;
repeated int32 token_ids = 2;
......
......@@ -174,10 +174,10 @@ class GenerateStreamChunk(_message.Message):
prompt_tokens: int
completion_tokens: int
cached_tokens: int
output_logprobs: LogProbs
output_logprobs: OutputLogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float]
input_logprobs: LogProbs
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
input_logprobs: InputLogProbs
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ...
class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
......@@ -196,12 +196,12 @@ class GenerateComplete(_message.Message):
prompt_tokens: int
completion_tokens: int
cached_tokens: int
output_logprobs: LogProbs
output_logprobs: OutputLogProbs
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
matched_token_id: int
matched_stop_str: str
input_logprobs: LogProbs
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
input_logprobs: InputLogProbs
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ...
class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details")
......@@ -213,7 +213,7 @@ class GenerateError(_message.Message):
details: str
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
class LogProbs(_message.Message):
class OutputLogProbs(_message.Message):
__slots__ = ("token_logprobs", "token_ids", "top_logprobs")
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
......@@ -223,6 +223,22 @@ class LogProbs(_message.Message):
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
class InputLogProbs(_message.Message):
__slots__ = ("token_logprobs", "token_ids", "top_logprobs")
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
token_logprobs: _containers.RepeatedCompositeFieldContainer[InputTokenLogProb]
token_ids: _containers.RepeatedScalarFieldContainer[int]
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
def __init__(self, token_logprobs: _Optional[_Iterable[_Union[InputTokenLogProb, _Mapping]]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
class InputTokenLogProb(_message.Message):
__slots__ = ("value",)
VALUE_FIELD_NUMBER: _ClassVar[int]
value: float
def __init__(self, value: _Optional[float] = ...) -> None: ...
class TopLogProbs(_message.Message):
__slots__ = ("values", "token_ids")
VALUES_FIELD_NUMBER: _ClassVar[int]
......
......@@ -175,13 +175,13 @@ message GenerateStreamChunk {
int32 cached_tokens = 4;
// Output logprobs (if requested) - incremental for streaming
LogProbs output_logprobs = 5;
OutputLogProbs output_logprobs = 5;
// Hidden states (if requested)
repeated float hidden_states = 6;
// Input logprobs (if requested) - only in first chunk
LogProbs input_logprobs = 7;
InputLogProbs input_logprobs = 7;
}
message GenerateComplete {
......@@ -197,7 +197,7 @@ message GenerateComplete {
int32 cached_tokens = 5;
// Output logprobs if requested (cumulative)
LogProbs output_logprobs = 6;
OutputLogProbs output_logprobs = 6;
// All hidden states if requested
repeated HiddenStates all_hidden_states = 7;
......@@ -209,7 +209,7 @@ message GenerateComplete {
}
// Input logprobs if requested (for prompt tokens)
LogProbs input_logprobs = 10;
InputLogProbs input_logprobs = 10;
}
message GenerateError {
......@@ -218,7 +218,8 @@ message GenerateError {
string details = 3;
}
message LogProbs {
// Output logprobs - all values are present (no None)
message OutputLogProbs {
repeated float token_logprobs = 1;
repeated int32 token_ids = 2;
......@@ -226,6 +227,20 @@ message LogProbs {
repeated TopLogProbs top_logprobs = 3;
}
// Input logprobs - first token has no logprob (None)
message InputLogProbs {
repeated InputTokenLogProb token_logprobs = 1;
repeated int32 token_ids = 2;
// Top logprobs at each position
repeated TopLogProbs top_logprobs = 3;
}
// Wrapper to represent optional logprob (first input token has no logprob)
message InputTokenLogProb {
optional float value = 1;
}
message TopLogProbs {
repeated float values = 1;
repeated int32 token_ids = 2;
......
......@@ -1239,7 +1239,7 @@ impl GrpcRouter {
/// Note: Always decodes with skip_special_tokens=false to show actual tokens generated
fn convert_proto_to_openai_logprobs(
&self,
proto_logprobs: &proto::LogProbs,
proto_logprobs: &proto::OutputLogProbs,
) -> Result<crate::protocols::spec::ChatLogProbs, String> {
let mut content_items = Vec::new();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment