"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8e35ef0142cb8445c608105d06c53594085f8aed"
Unverified Commit 5e7fdc79 authored by Keyang Ru, committed by GitHub

[OAI Server Refactor] [ChatCompletions & Completions] Support Return Hidden State (#7329)


Signed-off-by: keru <rukeyang@gmail.com>
parent 4d8d9b8e
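Before the diff itself, a hedged sketch of how a client could exercise the new parameter once this change is deployed. This is not part of the commit: the server address, model name, and reliance on the openai-python `extra_body` pass-through (and on the SDK retaining unknown response fields) are assumptions for illustration.

```python
# Hypothetical usage sketch (not part of this commit). Assumes a local
# SGLang server exposing the OpenAI-compatible API; base_url and model
# name are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
    # Non-standard field: extra_body merges extra keys into the request body.
    extra_body={"return_hidden_states": True},
)

# hidden_states is not in the OpenAI SDK types, so read it from the raw
# dict form of the parsed response (assumes the SDK keeps unknown fields).
choice = resp.model_dump()["choices"][0]
print(choice.get("hidden_states"))  # last-token hidden state, if returned
```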
@@ -16,7 +16,13 @@
 import time
 from typing import Dict, List, Optional, Union

-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import (
+    BaseModel,
+    Field,
+    field_validator,
+    model_serializer,
+    model_validator,
+)
 from typing_extensions import Literal
@@ -167,6 +173,7 @@ class CompletionRequest(BaseModel):
     temperature: float = 1.0
     top_p: float = 1.0
     user: Optional[str] = None
+    return_hidden_states: bool = False

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     top_k: int = -1
@@ -202,6 +209,14 @@ class CompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Literal["stop", "length", "content_filter", "abort"]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class CompletionResponse(BaseModel):
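The `@model_serializer(mode="wrap")` hook added here (and repeated on the other response models below) is what keeps the new field invisible unless it is populated: `handler(self)` yields pydantic's default serialized dict, and the method strips `hidden_states` when it is `None`, so existing clients see an unchanged schema. A self-contained sketch of the mechanism, with a hypothetical model name:

```python
# Minimal demonstration of the wrap-mode serializer pattern used above;
# the Example model is hypothetical, not part of the protocol module.
from typing import Optional

from pydantic import BaseModel, model_serializer


class Example(BaseModel):
    text: str
    hidden_states: Optional[object] = None

    @model_serializer(mode="wrap")
    def _serialize(self, handler):
        data = handler(self)  # pydantic's default field dict
        if self.hidden_states is None:
            data.pop("hidden_states", None)  # drop the key entirely
        return data


print(Example(text="hi").model_dump())
# -> {'text': 'hi'}
print(Example(text="hi", hidden_states=[0.1, 0.2]).model_dump())
# -> {'text': 'hi', 'hidden_states': [0.1, 0.2]}
```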
@@ -219,6 +234,14 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class CompletionStreamResponse(BaseModel):
@@ -376,6 +399,7 @@ class ChatCompletionRequest(BaseModel):
     tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
         default="auto", examples=["none"]
     )  # noqa
+    return_hidden_states: bool = False

     @model_validator(mode="before")
     @classmethod
@@ -437,6 +461,14 @@ class ChatCompletionResponseChoice(BaseModel):
         "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
     ]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class ChatCompletionResponse(BaseModel):
@@ -453,6 +485,14 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class ChatCompletionResponseStreamChoice(BaseModel):
...
@@ -30,6 +30,7 @@ from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
 from sglang.srt.entrypoints.openai.utils import (
     detect_template_content_format,
     process_content_for_template_format,
+    process_hidden_states_from_ret,
     to_openai_style_logprobs,
 )
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
@@ -99,6 +100,7 @@ class OpenAIServingChat(OpenAIServingBase):
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
+            return_hidden_states=request.return_hidden_states,
         )

         return adapted_request, request
@@ -402,6 +404,7 @@ class OpenAIServingChat(OpenAIServingBase):
         prompt_tokens = {}
         completion_tokens = {}
         cached_tokens = {}
+        hidden_states = {}

         try:
             async for content in self.tokenizer_manager.generate_request(
@@ -412,6 +415,7 @@ class OpenAIServingChat(OpenAIServingBase):
                 prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                 completion_tokens[index] = content["meta_info"]["completion_tokens"]
                 cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                hidden_states[index] = content["meta_info"].get("hidden_states", None)

                 # Handle logprobs
                 choice_logprobs = None
@@ -544,6 +548,31 @@ class OpenAIServingChat(OpenAIServingBase):
                 )
                 yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"

+            # Send hidden states if requested
+            if request.return_hidden_states and hidden_states:
+                for index, choice_hidden_states in hidden_states.items():
+                    if choice_hidden_states:
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=int(time.time()),
+                            choices=[
+                                ChatCompletionResponseStreamChoice(
+                                    index=index,
+                                    delta=DeltaMessage(
+                                        hidden_states=last_token_hidden_states
+                                    ),
+                                    finish_reason=finish_reason_type,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
+
             # Additional usage chunk
             if request.stream_options and request.stream_options.include_usage:
                 usage = UsageProcessor.calculate_streaming_usage(
@@ -608,6 +637,9 @@ class OpenAIServingChat(OpenAIServingBase):
             if request.logprobs:
                 choice_logprobs = self._process_response_logprobs(ret_item)

+            # Handle hidden states
+            hidden_states = process_hidden_states_from_ret(ret_item, request)
+
             finish_reason = ret_item["meta_info"]["finish_reason"]
             text = ret_item["text"]
@@ -654,6 +686,7 @@ class OpenAIServingChat(OpenAIServingBase):
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )
             choices.append(choice_data)
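Taken together with the streaming branch above, a chat stream with `return_hidden_states` set now carries one extra chunk per choice, emitted after the finish-reason chunk and before any usage chunk, with the vector inside `delta.hidden_states`. A rough client-side sketch of fishing that chunk out of the SSE stream; the endpoint, model, and payload shape are assumptions:

```python
# Hypothetical client sketch (not part of this commit): parse the SSE
# stream manually and pick out the chunk whose delta carries hidden_states.
import json

import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",  # placeholder address
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
        "return_hidden_states": True,
    },
    stream=True,
)

for raw in resp.iter_lines():
    if not raw or not raw.startswith(b"data: "):
        continue
    payload = raw[len(b"data: "):]
    if payload == b"[DONE]":
        break
    delta = json.loads(payload)["choices"][0]["delta"]
    if "hidden_states" in delta:
        # Present only in the dedicated chunk, thanks to the serializer.
        print("hidden state length:", len(delta["hidden_states"]))
```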
...
@@ -19,7 +19,10 @@ from sglang.srt.entrypoints.openai.protocol import (
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
 from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
-from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
+from sglang.srt.entrypoints.openai.utils import (
+    process_hidden_states_from_ret,
+    to_openai_style_logprobs,
+)
 from sglang.srt.managers.io_struct import GenerateReqInput

 logger = logging.getLogger(__name__)
@@ -76,6 +79,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
+            return_hidden_states=request.return_hidden_states,
         )

         return adapted_request, request
@@ -188,6 +192,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
                 delta = text[len(stream_buffer) :]
                 stream_buffers[index] = stream_buffer + delta
                 finish_reason = content["meta_info"]["finish_reason"]
+                hidden_states = content["meta_info"].get("hidden_states", None)

                 choice_data = CompletionResponseStreamChoice(
                     index=index,
@@ -210,6 +215,30 @@ class OpenAIServingCompletion(OpenAIServingBase):
                 yield f"data: {chunk.model_dump_json()}\n\n"

+            if request.return_hidden_states and hidden_states:
+                for index, choice_hidden_states in hidden_states.items():
+                    if choice_hidden_states:
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = CompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=created,
+                            object="text_completion",
+                            choices=[
+                                CompletionResponseStreamChoice(
+                                    index=index,
+                                    text="",
+                                    hidden_states=last_token_hidden_states,
+                                    finish_reason=None,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
+
             # Handle final usage chunk
             if request.stream_options and request.stream_options.include_usage:
                 usage = UsageProcessor.calculate_streaming_usage(
@@ -304,6 +333,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
                 output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
             )

+        # Handle hidden states
+        hidden_states = process_hidden_states_from_ret(ret_item, request)
+
         finish_reason = ret_item["meta_info"]["finish_reason"]

         choice_data = CompletionResponseChoice(
@@ -316,6 +348,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
                 if finish_reason and "matched" in finish_reason
                 else None
             ),
+            hidden_states=hidden_states,
         )
         choices.append(choice_data)
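The non-streaming completions path mirrors the chat path: `process_hidden_states_from_ret` reduces the per-token list to the last token's vector, which then rides on the choice. A minimal request/response sketch under the same placeholder assumptions as above:

```python
# Hypothetical /v1/completions call; server address and model name are
# placeholders, and the hidden_states key appears only when requested.
import requests

body = {
    "model": "default",
    "prompt": "The capital of France is",
    "max_tokens": 8,
    "return_hidden_states": True,
}
data = requests.post("http://localhost:30000/v1/completions", json=body).json()

choice = data["choices"][0]
hs = choice.get("hidden_states")  # key is absent entirely when not requested
print(choice["text"], None if hs is None else len(hs))
```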
...
 import logging
+from typing import Any, Dict, List, Optional, Union

 import jinja2.nodes
 import transformers.utils.chat_template_utils as hf_chat_utils

-from sglang.srt.entrypoints.openai.protocol import LogProbs
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    LogProbs,
+)

 logger = logging.getLogger(__name__)
@@ -205,3 +210,28 @@ def to_openai_style_logprobs(
         append_top_logprobs(output_top_logprobs)

     return ret_logprobs
+
+
+def process_hidden_states_from_ret(
+    ret_item: Dict[str, Any],
+    request: Union[
+        ChatCompletionRequest,
+        CompletionRequest,
+    ],
+) -> Optional[List]:
+    """Process hidden states from a ret item in a non-streaming response.
+
+    Args:
+        ret_item: Response item containing meta_info
+        request: The original request object
+
+    Returns:
+        Processed hidden states for the last token, or None
+    """
+    if not request.return_hidden_states:
+        return None
+    hidden_states = ret_item["meta_info"].get("hidden_states", None)
+    if hidden_states is not None:
+        hidden_states = hidden_states[-1] if len(hidden_states) > 1 else []
+    return hidden_states
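One behavior of this helper worth calling out: because of the `len(hidden_states) > 1` check, a run that produced exactly one hidden-state entry serializes as an empty list rather than that entry, and only the last token's vector is ever returned. A small sketch of the reduction, using a stand-in request object (the real call sites pass the pydantic request models, and the meta_info layout is inferred from those call sites):

```python
# Illustrative call with a faked request object; SimpleNamespace works here
# because the helper only reads the return_hidden_states attribute.
from types import SimpleNamespace

from sglang.srt.entrypoints.openai.utils import process_hidden_states_from_ret

ret_item = {
    "meta_info": {
        # one (shortened) vector per generated token
        "hidden_states": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    }
}

on = SimpleNamespace(return_hidden_states=True)
off = SimpleNamespace(return_hidden_states=False)

print(process_hidden_states_from_ret(ret_item, on))   # [0.5, 0.6] (last token)
print(process_hidden_states_from_ret(ret_item, off))  # None
```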
@@ -632,6 +632,51 @@ class TestStreamingModels(unittest.TestCase):
         self.assertEqual(response.choices[0].delta.content, "Hello")


+class TestModelSerialization(unittest.TestCase):
+    """Test model serialization with hidden states"""
+
+    def test_hidden_states_excluded_when_none(self):
+        """Test that None hidden_states are excluded with exclude_none=True"""
+        choice = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content="Hello"),
+            finish_reason="stop",
+            hidden_states=None,
+        )
+        response = ChatCompletionResponse(
+            id="test-id",
+            model="test-model",
+            choices=[choice],
+            usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
+        )
+
+        # Test exclude_none serialization (should exclude None hidden_states)
+        data = response.model_dump(exclude_none=True)
+        self.assertNotIn("hidden_states", data["choices"][0])
+
+    def test_hidden_states_included_when_not_none(self):
+        """Test that non-None hidden_states are included"""
+        choice = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content="Hello"),
+            finish_reason="stop",
+            hidden_states=[0.1, 0.2, 0.3],
+        )
+        response = ChatCompletionResponse(
+            id="test-id",
+            model="test-model",
+            choices=[choice],
+            usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
+        )
+
+        # Test exclude_none serialization (should include non-None hidden_states)
+        data = response.model_dump(exclude_none=True)
+        self.assertIn("hidden_states", data["choices"][0])
+        self.assertEqual(data["choices"][0]["hidden_states"], [0.1, 0.2, 0.3])
+
+
 class TestValidationEdgeCases(unittest.TestCase):
     """Test edge cases and validation scenarios"""
...