sglang · commit 5e7fdc79 (unverified)

[OAI Server Refactor] [ChatCompletions & Completions] Support Return Hidden State (#7329)

Authored Jun 20, 2025 by Keyang Ru; committed via GitHub on Jun 20, 2025
Signed-off-by: keru <rukeyang@gmail.com>
Parent: 4d8d9b8e

Showing 5 changed files with 184 additions and 3 deletions (+184, -3)
python/sglang/srt/entrypoints/openai/protocol.py             +41  -1
python/sglang/srt/entrypoints/openai/serving_chat.py         +33  -0
python/sglang/srt/entrypoints/openai/serving_completions.py  +34  -1
python/sglang/srt/entrypoints/openai/utils.py                +31  -1
test/srt/openai/test_protocol.py                             +45  -0
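
The commit threads a new return_hidden_states request flag through both OpenAI-compatible endpoints and returns the last generated token's hidden state on the response. A minimal client-side sketch of opting in (the server URL and model name are illustrative, not part of this diff; depending on the server version, hidden-state capture may also need to be enabled at launch, e.g. with --enable-return-hidden-states):

    import openai

    client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

    response = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "Hello"}],
        # Non-standard field: the openai client forwards extra_body keys
        # verbatim, and this diff makes the SGLang server accept it.
        extra_body={"return_hidden_states": True},
    )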
python/sglang/srt/entrypoints/openai/protocol.py

@@ -16,7 +16,13 @@
 import time
 from typing import Dict, List, Optional, Union

-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import (
+    BaseModel,
+    Field,
+    field_validator,
+    model_serializer,
+    model_validator,
+)
 from typing_extensions import Literal

@@ -167,6 +173,7 @@ class CompletionRequest(BaseModel):
     temperature: float = 1.0
     top_p: float = 1.0
     user: Optional[str] = None
+    return_hidden_states: bool = False

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     top_k: int = -1

@@ -202,6 +209,14 @@ class CompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Literal["stop", "length", "content_filter", "abort"]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class CompletionResponse(BaseModel):

@@ -219,6 +234,14 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class CompletionStreamResponse(BaseModel):

@@ -376,6 +399,7 @@ class ChatCompletionRequest(BaseModel):
     tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
         default="auto", examples=["none"]
     )  # noqa
+    return_hidden_states: bool = False

     @model_validator(mode="before")
     @classmethod

@@ -437,6 +461,14 @@ class ChatCompletionResponseChoice(BaseModel):
         "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
     ]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class ChatCompletionResponse(BaseModel):

@@ -453,6 +485,14 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class ChatCompletionResponseStreamChoice(BaseModel):
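All four response models above use the same wrap-mode serializer so that hidden_states disappears from the JSON payload whenever it is None, even without exclude_none. A self-contained toy sketch of that pattern (the Choice class here is illustrative, not the actual sglang model):

    from typing import Optional

    from pydantic import BaseModel, model_serializer


    class Choice(BaseModel):
        text: str
        hidden_states: Optional[object] = None

        @model_serializer(mode="wrap")
        def _serialize(self, handler):
            data = handler(self)  # run the default field-by-field serialization
            if self.hidden_states is None:
                data.pop("hidden_states", None)  # drop the key entirely when unset
            return data


    print(Choice(text="hi").model_dump())                       # {'text': 'hi'}
    print(Choice(text="hi", hidden_states=[0.1]).model_dump())  # key is kept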
python/sglang/srt/entrypoints/openai/serving_chat.py

@@ -30,6 +30,7 @@ from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
 from sglang.srt.entrypoints.openai.utils import (
     detect_template_content_format,
     process_content_for_template_format,
+    process_hidden_states_from_ret,
     to_openai_style_logprobs,
 )
 from sglang.srt.function_call.function_call_parser import FunctionCallParser

@@ -99,6 +100,7 @@ class OpenAIServingChat(OpenAIServingBase):
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
+            return_hidden_states=request.return_hidden_states,
         )

         return adapted_request, request

@@ -402,6 +404,7 @@ class OpenAIServingChat(OpenAIServingBase):
         prompt_tokens = {}
         completion_tokens = {}
         cached_tokens = {}
+        hidden_states = {}

         try:
             async for content in self.tokenizer_manager.generate_request(

@@ -412,6 +415,7 @@ class OpenAIServingChat(OpenAIServingBase):
                 prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                 completion_tokens[index] = content["meta_info"]["completion_tokens"]
                 cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                hidden_states[index] = content["meta_info"].get("hidden_states", None)

                 # Handle logprobs
                 choice_logprobs = None

@@ -544,6 +548,31 @@ class OpenAIServingChat(OpenAIServingBase):
                 )
                 yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"

+            # Send hidden states if requested
+            if request.return_hidden_states and hidden_states:
+                for index, choice_hidden_states in hidden_states.items():
+                    if choice_hidden_states:
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=int(time.time()),
+                            choices=[
+                                ChatCompletionResponseStreamChoice(
+                                    index=index,
+                                    delta=DeltaMessage(
+                                        hidden_states=last_token_hidden_states
+                                    ),
+                                    finish_reason=finish_reason_type,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
+
             # Additional usage chunk
             if request.stream_options and request.stream_options.include_usage:
                 usage = UsageProcessor.calculate_streaming_usage(

@@ -608,6 +637,9 @@ class OpenAIServingChat(OpenAIServingBase):
             if request.logprobs:
                 choice_logprobs = self._process_response_logprobs(ret_item)

+            # Handle hidden states
+            hidden_states = process_hidden_states_from_ret(ret_item, request)
+
             finish_reason = ret_item["meta_info"]["finish_reason"]
             text = ret_item["text"]

@@ -654,6 +686,7 @@ class OpenAIServingChat(OpenAIServingBase):
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )
             choices.append(choice_data)
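In streaming mode the hidden state is not attached to the normal content deltas; after the finish-reason chunk, the server emits one extra chunk whose delta carries hidden_states. A hedged consumer sketch using raw SSE over HTTP (server URL and model name are assumptions, not part of this diff):

    import json

    import requests

    resp = requests.post(
        "http://localhost:30000/v1/chat/completions",
        json={
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": True,
            "return_hidden_states": True,
        },
        stream=True,
    )

    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        delta = json.loads(payload)["choices"][0]["delta"]
        # Per the serializer in protocol.py, the key only appears on the
        # dedicated hidden-states chunk.
        if "hidden_states" in delta:
            print("last-token hidden state:", delta["hidden_states"])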
python/sglang/srt/entrypoints/openai/serving_completions.py

@@ -19,7 +19,10 @@ from sglang.srt.entrypoints.openai.protocol import (
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
 from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
-from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
+from sglang.srt.entrypoints.openai.utils import (
+    process_hidden_states_from_ret,
+    to_openai_style_logprobs,
+)
 from sglang.srt.managers.io_struct import GenerateReqInput

 logger = logging.getLogger(__name__)

@@ -76,6 +79,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
+            return_hidden_states=request.return_hidden_states,
         )

         return adapted_request, request

@@ -188,6 +192,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
                 delta = text[len(stream_buffer) :]
                 stream_buffers[index] = stream_buffer + delta
                 finish_reason = content["meta_info"]["finish_reason"]
+                hidden_states = content["meta_info"].get("hidden_states", None)

                 choice_data = CompletionResponseStreamChoice(
                     index=index,

@@ -210,6 +215,30 @@ class OpenAIServingCompletion(OpenAIServingBase):
                 yield f"data: {chunk.model_dump_json()}\n\n"

+            if request.return_hidden_states and hidden_states:
+                for index, choice_hidden_states in hidden_states.items():
+                    if choice_hidden_states:
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = CompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=created,
+                            object="text_completion",
+                            choices=[
+                                CompletionResponseStreamChoice(
+                                    index=index,
+                                    text="",
+                                    hidden_states=last_token_hidden_states,
+                                    finish_reason=None,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
+
             # Handle final usage chunk
             if request.stream_options and request.stream_options.include_usage:
                 usage = UsageProcessor.calculate_streaming_usage(

@@ -304,6 +333,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
                     output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
                 )

+            # Handle hidden states
+            hidden_states = process_hidden_states_from_ret(ret_item, request)
+
             finish_reason = ret_item["meta_info"]["finish_reason"]

             choice_data = CompletionResponseChoice(

@@ -316,6 +348,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
                     if finish_reason and "matched" in finish_reason
                     else None
                 ),
+                hidden_states=hidden_states,
             )
             choices.append(choice_data)
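The non-streaming completions path is simpler: process_hidden_states_from_ret() puts the last token's hidden state directly on the choice. A hedged request/response sketch (local server and model name assumed):

    import requests

    resp = requests.post(
        "http://localhost:30000/v1/completions",
        json={
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "prompt": "The capital of France is",
            "max_tokens": 8,
            "return_hidden_states": True,
        },
    ).json()

    choice = resp["choices"][0]
    print(choice["text"])
    # hidden_states is present only because return_hidden_states was set;
    # its length should match the model's hidden size.
    print(len(choice["hidden_states"]))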
python/sglang/srt/entrypoints/openai/utils.py

 import logging
 from typing import Any, Dict, List, Optional, Union

 import jinja2.nodes
 import transformers.utils.chat_template_utils as hf_chat_utils

-from sglang.srt.entrypoints.openai.protocol import LogProbs
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    LogProbs,
+)

 logger = logging.getLogger(__name__)

@@ -205,3 +210,28 @@ def to_openai_style_logprobs(
         append_top_logprobs(output_top_logprobs)

     return ret_logprobs
+
+
+def process_hidden_states_from_ret(
+    ret_item: Dict[str, Any],
+    request: Union[
+        ChatCompletionRequest,
+        CompletionRequest,
+    ],
+) -> Optional[List]:
+    """Process hidden states from a ret item in non-streaming response.
+
+    Args:
+        ret_item: Response item containing meta_info
+        request: The original request object
+
+    Returns:
+        Processed hidden states for the last token, or None
+    """
+    if not request.return_hidden_states:
+        return None
+
+    hidden_states = ret_item["meta_info"].get("hidden_states", None)
+    if hidden_states is not None:
+        hidden_states = hidden_states[-1] if len(hidden_states) > 1 else []
+    return hidden_states
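Note one quirk of the helper as written: when meta_info carries hidden states for only a single token, it returns [] rather than that token's vector. A toy illustration of the behavior (the ret_item values are invented, and constructing CompletionRequest with just model and prompt is an assumption about its required fields):

    from sglang.srt.entrypoints.openai.protocol import CompletionRequest
    from sglang.srt.entrypoints.openai.utils import process_hidden_states_from_ret

    # Fake ret_item with per-token hidden-state vectors for two tokens.
    ret_item = {"meta_info": {"hidden_states": [[0.1, 0.2], [0.3, 0.4]]}}

    on = CompletionRequest(model="m", prompt="p", return_hidden_states=True)
    off = CompletionRequest(model="m", prompt="p")

    assert process_hidden_states_from_ret(ret_item, on) == [0.3, 0.4]  # last token
    assert process_hidden_states_from_ret(ret_item, off) is None       # flag unset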
test/srt/openai/test_protocol.py

@@ -632,6 +632,51 @@ class TestStreamingModels(unittest.TestCase):
         self.assertEqual(response.choices[0].delta.content, "Hello")


+class TestModelSerialization(unittest.TestCase):
+    """Test model serialization with hidden states"""
+
+    def test_hidden_states_excluded_when_none(self):
+        """Test that None hidden_states are excluded with exclude_none=True"""
+        choice = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content="Hello"),
+            finish_reason="stop",
+            hidden_states=None,
+        )
+        response = ChatCompletionResponse(
+            id="test-id",
+            model="test-model",
+            choices=[choice],
+            usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
+        )
+
+        # Test exclude_none serialization (should exclude None hidden_states)
+        data = response.model_dump(exclude_none=True)
+        self.assertNotIn("hidden_states", data["choices"][0])
+
+    def test_hidden_states_included_when_not_none(self):
+        """Test that non-None hidden_states are included"""
+        choice = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content="Hello"),
+            finish_reason="stop",
+            hidden_states=[0.1, 0.2, 0.3],
+        )
+        response = ChatCompletionResponse(
+            id="test-id",
+            model="test-model",
+            choices=[choice],
+            usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
+        )
+
+        # Test exclude_none serialization (should include non-None hidden_states)
+        data = response.model_dump(exclude_none=True)
+        self.assertIn("hidden_states", data["choices"][0])
+        self.assertEqual(data["choices"][0]["hidden_states"], [0.1, 0.2, 0.3])
+
+
 class TestValidationEdgeCases(unittest.TestCase):
     """Test edge cases and validation scenarios"""