Unverified Commit 5e7fdc79 authored by Keyang Ru, committed by GitHub

[OAI Server Refactor] [ChatCompletions & Completions] Support Return Hidden State (#7329)


Signed-off-by: keru <rukeyang@gmail.com>
parent 4d8d9b8e
@@ -16,7 +16,13 @@
 import time
 from typing import Dict, List, Optional, Union

-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import (
+    BaseModel,
+    Field,
+    field_validator,
+    model_serializer,
+    model_validator,
+)
 from typing_extensions import Literal
@@ -167,6 +173,7 @@ class CompletionRequest(BaseModel):
     temperature: float = 1.0
     top_p: float = 1.0
     user: Optional[str] = None
+    return_hidden_states: bool = False

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     top_k: int = -1
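Since the new `return_hidden_states` flag travels in the request body alongside the other SRT-only extras, a plain JSON POST is enough to exercise it. A minimal client-side sketch; the port, model name, and prompt are placeholders, and depending on the server version, hidden-state return may also need to be enabled at launch (e.g. via an `--enable-return-hidden-states` style flag):

```python
# Hypothetical client call against a local SGLang server; URL/model are placeholders.
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "my-model",
        "prompt": "The capital of France is",
        "max_tokens": 4,
        "return_hidden_states": True,  # the SRT-only field added in this PR
    },
)
choice = resp.json()["choices"][0]
# The key is present only when the server attached hidden states for this choice.
print(choice.get("hidden_states"))
```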
@@ -202,6 +209,14 @@ class CompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Literal["stop", "length", "content_filter", "abort"]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class CompletionResponse(BaseModel):
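This wrap-mode serializer, repeated on each response model below, is what keeps `hidden_states` out of every serialized payload unless it was actually populated; callers do not need `exclude_none=True`. A self-contained sketch of the same pattern:

```python
# Self-contained demo of the wrap-mode serializer pattern used in this PR.
from typing import Optional

from pydantic import BaseModel, model_serializer


class Choice(BaseModel):
    text: str
    hidden_states: Optional[object] = None

    @model_serializer(mode="wrap")
    def _serialize(self, handler):
        data = handler(self)  # default serialization as a dict
        if self.hidden_states is None:
            data.pop("hidden_states", None)  # drop the key entirely, not just null it
        return data


print(Choice(text="hi").model_dump())                       # {'text': 'hi'}
print(Choice(text="hi", hidden_states=[0.1]).model_dump())  # {'text': 'hi', 'hidden_states': [0.1]}
```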
@@ -219,6 +234,14 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class CompletionStreamResponse(BaseModel):
@@ -376,6 +399,7 @@ class ChatCompletionRequest(BaseModel):
     tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
         default="auto", examples=["none"]
     )  # noqa
+    return_hidden_states: bool = False

     @model_validator(mode="before")
     @classmethod
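The chat endpoint accepts the same flag. A sketch using the official `openai` client, whose `extra_body` parameter forwards non-standard fields; the base URL and model name are placeholders, and exactly how the SDK surfaces unknown response fields varies by version:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Hello"}],
    extra_body={"return_hidden_states": True},  # forwarded verbatim in the JSON body
)
# Recent openai SDKs keep undocumented response fields; fall back to None if absent.
print(getattr(resp.choices[0], "hidden_states", None))
```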
@@ -437,6 +461,14 @@ class ChatCompletionResponseChoice(BaseModel):
         "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
     ]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class ChatCompletionResponse(BaseModel):
@@ -453,6 +485,14 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class ChatCompletionResponseStreamChoice(BaseModel):
......
@@ -30,6 +30,7 @@ from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
 from sglang.srt.entrypoints.openai.utils import (
     detect_template_content_format,
     process_content_for_template_format,
+    process_hidden_states_from_ret,
     to_openai_style_logprobs,
 )
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
@@ -99,6 +100,7 @@ class OpenAIServingChat(OpenAIServingBase):
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
+            return_hidden_states=request.return_hidden_states,
         )

         return adapted_request, request
@@ -402,6 +404,7 @@ class OpenAIServingChat(OpenAIServingBase):
         prompt_tokens = {}
         completion_tokens = {}
         cached_tokens = {}
+        hidden_states = {}

         try:
             async for content in self.tokenizer_manager.generate_request(
@@ -412,6 +415,7 @@ class OpenAIServingChat(OpenAIServingBase):
                 prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                 completion_tokens[index] = content["meta_info"]["completion_tokens"]
                 cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                hidden_states[index] = content["meta_info"].get("hidden_states", None)

                 # Handle logprobs
                 choice_logprobs = None
@@ -544,6 +548,31 @@ class OpenAIServingChat(OpenAIServingBase):
             )
             yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"

+        # Send hidden states if requested
+        if request.return_hidden_states and hidden_states:
+            for index, choice_hidden_states in hidden_states.items():
+                if choice_hidden_states:
+                    last_token_hidden_states = (
+                        choice_hidden_states[-1]
+                        if len(choice_hidden_states) > 1
+                        else []
+                    )
+                    hidden_states_chunk = ChatCompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        created=int(time.time()),
+                        choices=[
+                            ChatCompletionResponseStreamChoice(
+                                index=index,
+                                delta=DeltaMessage(
+                                    hidden_states=last_token_hidden_states
+                                ),
+                                finish_reason=finish_reason_type,
+                            )
+                        ],
+                        model=request.model,
+                    )
+                    yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
+
         # Additional usage chunk
         if request.stream_options and request.stream_options.include_usage:
             usage = UsageProcessor.calculate_streaming_usage(
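On the wire, the hidden states therefore arrive as one extra SSE chunk per choice, emitted after the finish-reason chunk with the vector inside `delta`. A hedged consumer sketch (URL and model name are placeholders):

```python
# Sketch: picking the hidden-states chunk out of a streaming chat response.
import json

import requests

with requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hi"}],
        "stream": True,
        "return_hidden_states": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line.startswith(b"data: "):
            continue  # skip keep-alive blanks
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        if not chunk["choices"]:
            continue  # e.g. a usage-only chunk
        delta = chunk["choices"][0].get("delta", {})
        if "hidden_states" in delta:
            print("last-token hidden state dims:", len(delta["hidden_states"]))
```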
@@ -608,6 +637,9 @@ class OpenAIServingChat(OpenAIServingBase):
         if request.logprobs:
             choice_logprobs = self._process_response_logprobs(ret_item)

+        # Handle hidden states
+        hidden_states = process_hidden_states_from_ret(ret_item, request)
+
         finish_reason = ret_item["meta_info"]["finish_reason"]
         text = ret_item["text"]
@@ -654,6 +686,7 @@ class OpenAIServingChat(OpenAIServingBase):
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )
             choices.append(choice_data)
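In the non-streaming chat path the processed value lands directly on the response choice. An illustrative final payload, with made-up values:

```python
# Illustrative response body (not actual output); hidden_states is the last token's vector.
{
    "id": "chatcmpl-...",
    "object": "chat.completion",
    "model": "my-model",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello!"},
            "finish_reason": "stop",
            "hidden_states": [0.12, -0.03, 0.88],
        }
    ],
}
```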
......
@@ -19,7 +19,10 @@ from sglang.srt.entrypoints.openai.protocol import (
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
 from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
-from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
+from sglang.srt.entrypoints.openai.utils import (
+    process_hidden_states_from_ret,
+    to_openai_style_logprobs,
+)
 from sglang.srt.managers.io_struct import GenerateReqInput

 logger = logging.getLogger(__name__)
@@ -76,6 +79,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
+            return_hidden_states=request.return_hidden_states,
         )

         return adapted_request, request
@@ -188,6 +192,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
                 delta = text[len(stream_buffer) :]
                 stream_buffers[index] = stream_buffer + delta
                 finish_reason = content["meta_info"]["finish_reason"]
+                hidden_states[index] = content["meta_info"].get("hidden_states", None)

                 choice_data = CompletionResponseStreamChoice(
                     index=index,
@@ -210,6 +215,30 @@ class OpenAIServingCompletion(OpenAIServingBase):
                 yield f"data: {chunk.model_dump_json()}\n\n"

+        if request.return_hidden_states and hidden_states:
+            for index, choice_hidden_states in hidden_states.items():
+                if choice_hidden_states:
+                    last_token_hidden_states = (
+                        choice_hidden_states[-1]
+                        if len(choice_hidden_states) > 1
+                        else []
+                    )
+                    hidden_states_chunk = CompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        created=created,
+                        object="text_completion",
+                        choices=[
+                            CompletionResponseStreamChoice(
+                                index=index,
+                                text="",
+                                hidden_states=last_token_hidden_states,
+                                finish_reason=None,
+                            )
+                        ],
+                        model=request.model,
+                    )
+                    yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
+
         # Handle final usage chunk
         if request.stream_options and request.stream_options.include_usage:
             usage = UsageProcessor.calculate_streaming_usage(
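The completions stream mirrors the chat path, except the field sits directly on the stream choice rather than inside a `delta`. Roughly what the extra chunk serializes to, with made-up values:

```python
# Illustrative chunk shape (values invented); emitted once per choice at end of stream.
{
    "id": "cmpl-...",
    "object": "text_completion",
    "created": 1700000000,
    "model": "my-model",
    "choices": [
        {
            "index": 0,
            "text": "",
            "hidden_states": [0.12, -0.03, 0.88],  # last generated token only
            "finish_reason": None,
        }
    ],
}
```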
@@ -304,6 +333,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
                 output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
             )

+        # Handle hidden states
+        hidden_states = process_hidden_states_from_ret(ret_item, request)
+
         finish_reason = ret_item["meta_info"]["finish_reason"]

         choice_data = CompletionResponseChoice(
@@ -316,6 +348,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )
             choices.append(choice_data)
......
 import logging
 from typing import Any, Dict, List, Optional, Union

 import jinja2.nodes
 import transformers.utils.chat_template_utils as hf_chat_utils

-from sglang.srt.entrypoints.openai.protocol import LogProbs
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    LogProbs,
+)

 logger = logging.getLogger(__name__)
@@ -205,3 +210,28 @@ def to_openai_style_logprobs(
         append_top_logprobs(output_top_logprobs)

     return ret_logprobs
+
+
+def process_hidden_states_from_ret(
+    ret_item: Dict[str, Any],
+    request: Union[
+        ChatCompletionRequest,
+        CompletionRequest,
+    ],
+) -> Optional[List]:
+    """Process hidden states from a ret item in non-streaming response.
+
+    Args:
+        ret_item: Response item containing meta_info
+        request: The original request object
+
+    Returns:
+        Processed hidden states for the last token, or None
+    """
+    if not request.return_hidden_states:
+        return None
+
+    hidden_states = ret_item["meta_info"].get("hidden_states", None)
+    if hidden_states is not None:
+        hidden_states = hidden_states[-1] if len(hidden_states) > 1 else []
+    return hidden_states
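The helper has three outcomes: `None` when the request did not ask for hidden states, the last entry when the server returned more than one per-token state, and an empty list for a degenerate single-entry list. A quick check, with `FakeRequest` as a hypothetical stand-in for the real request models:

```python
from sglang.srt.entrypoints.openai.utils import process_hidden_states_from_ret


class FakeRequest:
    # Hypothetical stand-in; the helper only reads this one attribute.
    return_hidden_states = True


ret_item = {"meta_info": {"hidden_states": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]}}
print(process_hidden_states_from_ret(ret_item, FakeRequest()))  # [0.5, 0.6]

ret_item = {"meta_info": {"hidden_states": [[0.1, 0.2]]}}
print(process_hidden_states_from_ret(ret_item, FakeRequest()))  # [] (single entry)

FakeRequest.return_hidden_states = False
print(process_hidden_states_from_ret(ret_item, FakeRequest()))  # None
```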
@@ -632,6 +632,51 @@ class TestStreamingModels(unittest.TestCase):
         self.assertEqual(response.choices[0].delta.content, "Hello")


+class TestModelSerialization(unittest.TestCase):
+    """Test model serialization with hidden states"""
+
+    def test_hidden_states_excluded_when_none(self):
+        """Test that None hidden_states are excluded with exclude_none=True"""
+        choice = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content="Hello"),
+            finish_reason="stop",
+            hidden_states=None,
+        )
+        response = ChatCompletionResponse(
+            id="test-id",
+            model="test-model",
+            choices=[choice],
+            usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
+        )
+
+        # Test exclude_none serialization (should exclude None hidden_states)
+        data = response.model_dump(exclude_none=True)
+        self.assertNotIn("hidden_states", data["choices"][0])
+
+    def test_hidden_states_included_when_not_none(self):
+        """Test that non-None hidden_states are included"""
+        choice = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content="Hello"),
+            finish_reason="stop",
+            hidden_states=[0.1, 0.2, 0.3],
+        )
+        response = ChatCompletionResponse(
+            id="test-id",
+            model="test-model",
+            choices=[choice],
+            usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
+        )
+
+        # Test exclude_none serialization (should include non-None hidden_states)
+        data = response.model_dump(exclude_none=True)
+        self.assertIn("hidden_states", data["choices"][0])
+        self.assertEqual(data["choices"][0]["hidden_states"], [0.1, 0.2, 0.3])
+
+
 class TestValidationEdgeCases(unittest.TestCase):
     """Test edge cases and validation scenarios"""
......