Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
72676cd6
Unverified
Commit
72676cd6
authored
Jun 21, 2025
by
Chang Su
Committed by
GitHub
Jun 21, 2025
Browse files
feat(oai refactor): Replace `openai_api` with `entrypoints/openai` (#7351)
Co-authored-by:
Jin Pan
<
jpan236@wisc.edu
>
parent
02bf31ef
Changes
43
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
308 additions
and
3577 deletions
+308
-3577
python/sglang/srt/function_call/llama32_detector.py
python/sglang/srt/function_call/llama32_detector.py
+1
-1
python/sglang/srt/function_call/mistral_detector.py
python/sglang/srt/function_call/mistral_detector.py
+1
-1
python/sglang/srt/function_call/pythonic_detector.py
python/sglang/srt/function_call/pythonic_detector.py
+1
-1
python/sglang/srt/function_call/qwen25_detector.py
python/sglang/srt/function_call/qwen25_detector.py
+1
-1
python/sglang/srt/jinja_template_utils.py
python/sglang/srt/jinja_template_utils.py
+6
-5
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+0
-6
python/sglang/srt/managers/template_manager.py
python/sglang/srt/managers/template_manager.py
+226
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+1
-6
python/sglang/srt/openai_api/adapter.py
python/sglang/srt/openai_api/adapter.py
+0
-2148
python/sglang/srt/openai_api/protocol.py
python/sglang/srt/openai_api/protocol.py
+0
-551
python/sglang/srt/reasoning_parser.py
python/sglang/srt/reasoning_parser.py
+21
-11
test/srt/openai/conftest.py
test/srt/openai/conftest.py
+0
-87
test/srt/openai/test_protocol.py
test/srt/openai/test_protocol.py
+0
-451
test/srt/openai/test_server.py
test/srt/openai/test_server.py
+0
-52
test/srt/openai/test_serving_chat.py
test/srt/openai/test_serving_chat.py
+11
-91
test/srt/openai/test_serving_completions.py
test/srt/openai/test_serving_completions.py
+14
-15
test/srt/openai/test_serving_embedding.py
test/srt/openai/test_serving_embedding.py
+13
-137
test/srt/run_suite.py
test/srt/run_suite.py
+5
-6
test/srt/test_function_call_parser.py
test/srt/test_function_call_parser.py
+1
-1
test/srt/test_jinja_template_utils.py
test/srt/test_jinja_template_utils.py
+6
-6
No files found.
python/sglang/srt/function_call/llama32_detector.py
View file @
72676cd6
...
...
@@ -2,6 +2,7 @@ import json
import
logging
from
typing
import
List
from
sglang.srt.entrypoints.openai.protocol
import
Tool
from
sglang.srt.function_call.base_format_detector
import
BaseFormatDetector
from
sglang.srt.function_call.core_types
import
(
StreamingParseResult
,
...
...
@@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc
,
)
from
sglang.srt.function_call.ebnf_composer
import
EBNFComposer
from
sglang.srt.openai_api.protocol
import
Tool
logger
=
logging
.
getLogger
(
__name__
)
...
...
python/sglang/srt/function_call/mistral_detector.py
View file @
72676cd6
...
...
@@ -3,6 +3,7 @@ import logging
import
re
from
typing
import
List
from
sglang.srt.entrypoints.openai.protocol
import
Tool
from
sglang.srt.function_call.base_format_detector
import
BaseFormatDetector
from
sglang.srt.function_call.core_types
import
(
StreamingParseResult
,
...
...
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc
,
)
from
sglang.srt.function_call.ebnf_composer
import
EBNFComposer
from
sglang.srt.openai_api.protocol
import
Tool
logger
=
logging
.
getLogger
(
__name__
)
...
...
python/sglang/srt/function_call/pythonic_detector.py
View file @
72676cd6
...
...
@@ -4,6 +4,7 @@ import logging
import
re
from
typing
import
List
,
Optional
from
sglang.srt.entrypoints.openai.protocol
import
Tool
from
sglang.srt.function_call.base_format_detector
import
BaseFormatDetector
from
sglang.srt.function_call.core_types
import
(
StreamingParseResult
,
...
...
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc
,
)
from
sglang.srt.function_call.ebnf_composer
import
EBNFComposer
from
sglang.srt.openai_api.protocol
import
Tool
logger
=
logging
.
getLogger
(
__name__
)
...
...
python/sglang/srt/function_call/qwen25_detector.py
View file @
72676cd6
...
...
@@ -3,6 +3,7 @@ import logging
import
re
from
typing
import
List
from
sglang.srt.entrypoints.openai.protocol
import
Tool
from
sglang.srt.function_call.base_format_detector
import
BaseFormatDetector
from
sglang.srt.function_call.core_types
import
(
StreamingParseResult
,
...
...
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
_GetInfoFunc
,
)
from
sglang.srt.function_call.ebnf_composer
import
EBNFComposer
from
sglang.srt.openai_api.protocol
import
Tool
logger
=
logging
.
getLogger
(
__name__
)
...
...
python/sglang/srt/
openai_api/
utils.py
→
python/sglang/srt/
jinja_template_
utils.py
View file @
72676cd6
"""
Utility functions for OpenAI API adapter.
"""Template utilities for Jinja template processing.
This module provides utilities for analyzing and processing Jinja chat templates,
including content format detection and message processing.
"""
import
logging
from
typing
import
Dict
,
List
import
jinja2
.nodes
import
jinja2
import
transformers.utils.chat_template_utils
as
hf_chat_utils
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str):
return
None
def
detect_template_content_format
(
chat_template
:
str
)
->
str
:
def
detect_
jinja_
template_content_format
(
chat_template
:
str
)
->
str
:
"""
Detect whether a chat template expects 'string' or 'openai' content format.
...
...
python/sglang/srt/managers/io_struct.py
View file @
72676cd6
...
...
@@ -864,12 +864,6 @@ class SetInternalStateReq:
server_args
:
Dict
[
str
,
Any
]
@
dataclass
class
V1RerankReqInput
:
query
:
str
documents
:
List
[
str
]
@
dataclass
class
SetInternalStateReqOutput
:
updated
:
bool
...
...
python/sglang/srt/managers/template_manager.py
0 → 100644
View file @
72676cd6
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Centralized template management for chat templates and completion templates.
This module provides a unified interface for managing both chat conversation templates
and code completion templates, eliminating global state and improving modularity.
"""
import
json
import
logging
import
os
from
typing
import
Optional
from
sglang.srt.code_completion_parser
import
(
CompletionTemplate
,
FimPosition
,
completion_template_exists
,
register_completion_template
,
)
from
sglang.srt.conversation
import
(
Conversation
,
SeparatorStyle
,
chat_template_exists
,
get_conv_template_by_model_path
,
register_conv_template
,
)
from
sglang.srt.jinja_template_utils
import
detect_jinja_template_content_format
logger
=
logging
.
getLogger
(
__name__
)
class
TemplateManager
:
"""
Centralized manager for chat and completion templates.
This class encapsulates all template-related state and operations,
eliminating the need for global variables and providing a clean
interface for template management.
"""
def
__init__
(
self
):
self
.
_chat_template_name
:
Optional
[
str
]
=
None
self
.
_completion_template_name
:
Optional
[
str
]
=
None
self
.
_jinja_template_content_format
:
Optional
[
str
]
=
None
@
property
def
chat_template_name
(
self
)
->
Optional
[
str
]:
"""Get the current chat template name."""
return
self
.
_chat_template_name
@
property
def
completion_template_name
(
self
)
->
Optional
[
str
]:
"""Get the current completion template name."""
return
self
.
_completion_template_name
@
property
def
jinja_template_content_format
(
self
)
->
Optional
[
str
]:
"""Get the detected template content format ('string' or 'openai' or None)."""
return
self
.
_jinja_template_content_format
def
load_chat_template
(
self
,
tokenizer_manager
,
chat_template_arg
:
str
,
model_path
:
str
)
->
None
:
"""
Load a chat template from various sources.
Args:
tokenizer_manager: The tokenizer manager instance
chat_template_arg: Template name or file path
model_path: Path to the model
"""
logger
.
info
(
f
"Loading chat template:
{
chat_template_arg
}
"
)
if
not
chat_template_exists
(
chat_template_arg
):
if
not
os
.
path
.
exists
(
chat_template_arg
):
raise
RuntimeError
(
f
"Chat template
{
chat_template_arg
}
is not a built-in template name "
"or a valid chat template file path."
)
if
chat_template_arg
.
endswith
(
".jinja"
):
self
.
_load_jinja_template
(
tokenizer_manager
,
chat_template_arg
)
else
:
self
.
_load_json_chat_template
(
chat_template_arg
)
else
:
self
.
_chat_template_name
=
chat_template_arg
def
guess_chat_template_from_model_path
(
self
,
model_path
:
str
)
->
None
:
"""
Infer chat template name from model path.
Args:
model_path: Path to the model
"""
template_name
=
get_conv_template_by_model_path
(
model_path
)
if
template_name
is
not
None
:
logger
.
info
(
f
"Inferred chat template from model path:
{
template_name
}
"
)
self
.
_chat_template_name
=
template_name
def
load_completion_template
(
self
,
completion_template_arg
:
str
)
->
None
:
"""
Load completion template for code completion.
Args:
completion_template_arg: Template name or file path
"""
logger
.
info
(
f
"Loading completion template:
{
completion_template_arg
}
"
)
if
not
completion_template_exists
(
completion_template_arg
):
if
not
os
.
path
.
exists
(
completion_template_arg
):
raise
RuntimeError
(
f
"Completion template
{
completion_template_arg
}
is not a built-in template name "
"or a valid completion template file path."
)
self
.
_load_json_completion_template
(
completion_template_arg
)
else
:
self
.
_completion_template_name
=
completion_template_arg
def
initialize_templates
(
self
,
tokenizer_manager
,
model_path
:
str
,
chat_template
:
Optional
[
str
]
=
None
,
completion_template
:
Optional
[
str
]
=
None
,
)
->
None
:
"""
Initialize all templates based on provided configuration.
Args:
tokenizer_manager: The tokenizer manager instance
model_path: Path to the model
chat_template: Optional chat template name/path
completion_template: Optional completion template name/path
"""
# Load chat template
if
chat_template
:
self
.
load_chat_template
(
tokenizer_manager
,
chat_template
,
model_path
)
else
:
self
.
guess_chat_template_from_model_path
(
model_path
)
# Load completion template
if
completion_template
:
self
.
load_completion_template
(
completion_template
)
def
_load_jinja_template
(
self
,
tokenizer_manager
,
template_path
:
str
)
->
None
:
"""Load a Jinja template file."""
with
open
(
template_path
,
"r"
)
as
f
:
chat_template
=
""
.
join
(
f
.
readlines
()).
strip
(
"
\n
"
)
tokenizer_manager
.
tokenizer
.
chat_template
=
chat_template
.
replace
(
"
\\
n"
,
"
\n
"
)
self
.
_chat_template_name
=
None
# Detect content format from the loaded template
self
.
_jinja_template_content_format
=
detect_jinja_template_content_format
(
chat_template
)
logger
.
info
(
f
"Detected chat template content format:
{
self
.
_jinja_template_content_format
}
"
)
def
_load_json_chat_template
(
self
,
template_path
:
str
)
->
None
:
"""Load a JSON chat template file."""
assert
template_path
.
endswith
(
".json"
),
"unrecognized format of chat template file"
with
open
(
template_path
,
"r"
)
as
filep
:
template
=
json
.
load
(
filep
)
try
:
sep_style
=
SeparatorStyle
[
template
[
"sep_style"
]]
except
KeyError
:
raise
ValueError
(
f
"Unknown separator style:
{
template
[
'sep_style'
]
}
"
)
from
None
register_conv_template
(
Conversation
(
name
=
template
[
"name"
],
system_template
=
template
[
"system"
]
+
"
\n
{system_message}"
,
system_message
=
template
.
get
(
"system_message"
,
""
),
roles
=
(
template
[
"user"
],
template
[
"assistant"
]),
sep_style
=
sep_style
,
sep
=
template
.
get
(
"sep"
,
"
\n
"
),
stop_str
=
template
[
"stop_str"
],
),
override
=
True
,
)
self
.
_chat_template_name
=
template
[
"name"
]
def
_load_json_completion_template
(
self
,
template_path
:
str
)
->
None
:
"""Load a JSON completion template file."""
assert
template_path
.
endswith
(
".json"
),
"unrecognized format of completion template file"
with
open
(
template_path
,
"r"
)
as
filep
:
template
=
json
.
load
(
filep
)
try
:
fim_position
=
FimPosition
[
template
[
"fim_position"
]]
except
KeyError
:
raise
ValueError
(
f
"Unknown fim position:
{
template
[
'fim_position'
]
}
"
)
from
None
register_completion_template
(
CompletionTemplate
(
name
=
template
[
"name"
],
fim_begin_token
=
template
[
"fim_begin_token"
],
fim_middle_token
=
template
[
"fim_middle_token"
],
fim_end_token
=
template
[
"fim_end_token"
],
fim_position
=
fim_position
,
),
override
=
True
,
)
self
.
_completion_template_name
=
template
[
"name"
]
python/sglang/srt/managers/tokenizer_manager.py
View file @
72676cd6
...
...
@@ -1058,12 +1058,7 @@ class TokenizerManager:
"lora_path"
,
]
)
out_skip_names
=
set
(
[
"text"
,
"output_ids"
,
]
)
out_skip_names
=
set
([
"text"
,
"output_ids"
,
"embedding"
])
elif
self
.
log_requests_level
==
1
:
max_length
=
2048
elif
self
.
log_requests_level
==
2
:
...
...
python/sglang/srt/openai_api/adapter.py
deleted
100644 → 0
View file @
02bf31ef
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conversion between OpenAI APIs and native SRT APIs"""
import
asyncio
import
base64
import
json
import
logging
import
os
import
time
import
uuid
from
http
import
HTTPStatus
from
typing
import
Dict
,
List
from
fastapi
import
HTTPException
,
Request
,
UploadFile
from
fastapi.responses
import
ORJSONResponse
,
StreamingResponse
from
pydantic
import
ValidationError
from
sglang.srt.code_completion_parser
import
(
generate_completion_prompt_from_request
,
is_completion_template_defined
,
)
from
sglang.srt.conversation
import
(
Conversation
,
SeparatorStyle
,
chat_template_exists
,
generate_chat_conv
,
generate_embedding_convs
,
get_conv_template_by_model_path
,
register_conv_template
,
)
from
sglang.srt.function_call.function_call_parser
import
FunctionCallParser
from
sglang.srt.managers.io_struct
import
(
EmbeddingReqInput
,
GenerateReqInput
,
V1RerankReqInput
,
)
from
sglang.srt.openai_api.protocol
import
(
BatchRequest
,
BatchResponse
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatCompletionTokenLogprob
,
ChatMessage
,
ChoiceLogprobs
,
CompletionRequest
,
CompletionResponse
,
CompletionResponseChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
DeltaMessage
,
EmbeddingObject
,
EmbeddingRequest
,
EmbeddingResponse
,
ErrorResponse
,
FileDeleteResponse
,
FileRequest
,
FileResponse
,
FunctionResponse
,
LogProbs
,
MultimodalEmbeddingInput
,
RerankResponse
,
ScoringRequest
,
ScoringResponse
,
ToolCall
,
TopLogprob
,
UsageInfo
,
)
from
sglang.srt.openai_api.utils
import
(
detect_template_content_format
,
process_content_for_template_format
,
)
from
sglang.srt.reasoning_parser
import
ReasoningParser
from
sglang.utils
import
convert_json_schema_to_str
,
get_exception_traceback
logger
=
logging
.
getLogger
(
__name__
)
chat_template_name
=
None
# Global cache for template content format detection (one model/template per instance)
# NOTE: A better approach would be to initialize the chat template format when the endpoint is created
_cached_chat_template
=
None
_cached_template_format
=
None
class
FileMetadata
:
def
__init__
(
self
,
filename
:
str
,
purpose
:
str
):
self
.
filename
=
filename
self
.
purpose
=
purpose
# In-memory storage for batch jobs and files
batch_storage
:
Dict
[
str
,
BatchResponse
]
=
{}
file_id_request
:
Dict
[
str
,
FileMetadata
]
=
{}
file_id_response
:
Dict
[
str
,
FileResponse
]
=
{}
# map file id to file path in SGLang backend
file_id_storage
:
Dict
[
str
,
str
]
=
{}
# backend storage directory
storage_dir
=
None
def
create_error_response
(
message
:
str
,
err_type
:
str
=
"BadRequestError"
,
status_code
:
HTTPStatus
=
HTTPStatus
.
BAD_REQUEST
,
):
error
=
ErrorResponse
(
message
=
message
,
type
=
err_type
,
code
=
status_code
.
value
)
return
ORJSONResponse
(
content
=
error
.
model_dump
(),
status_code
=
error
.
code
)
def
create_streaming_error_response
(
message
:
str
,
err_type
:
str
=
"BadRequestError"
,
status_code
:
HTTPStatus
=
HTTPStatus
.
BAD_REQUEST
,
)
->
str
:
error
=
ErrorResponse
(
message
=
message
,
type
=
err_type
,
code
=
status_code
.
value
)
json_str
=
json
.
dumps
({
"error"
:
error
.
model_dump
()})
return
json_str
def
load_chat_template_for_openai_api
(
tokenizer_manager
,
chat_template_arg
,
model_path
):
global
chat_template_name
logger
.
info
(
f
"Use chat template for the OpenAI-compatible API server:
{
chat_template_arg
}
"
)
if
not
chat_template_exists
(
chat_template_arg
):
if
not
os
.
path
.
exists
(
chat_template_arg
):
raise
RuntimeError
(
f
"Chat template
{
chat_template_arg
}
is not a built-in template name "
"or a valid chat template file path."
)
if
chat_template_arg
.
endswith
(
".jinja"
):
with
open
(
chat_template_arg
,
"r"
)
as
f
:
chat_template
=
""
.
join
(
f
.
readlines
()).
strip
(
"
\n
"
)
tokenizer_manager
.
tokenizer
.
chat_template
=
chat_template
.
replace
(
"
\\
n"
,
"
\n
"
)
chat_template_name
=
None
else
:
assert
chat_template_arg
.
endswith
(
".json"
),
"unrecognized format of chat template file"
with
open
(
chat_template_arg
,
"r"
)
as
filep
:
template
=
json
.
load
(
filep
)
try
:
sep_style
=
SeparatorStyle
[
template
[
"sep_style"
]]
except
KeyError
:
raise
ValueError
(
f
"Unknown separator style:
{
template
[
'sep_style'
]
}
"
)
from
None
register_conv_template
(
Conversation
(
name
=
template
[
"name"
],
system_template
=
template
[
"system"
]
+
"
\n
{system_message}"
,
system_message
=
template
.
get
(
"system_message"
,
""
),
roles
=
(
template
[
"user"
],
template
[
"assistant"
]),
sep_style
=
sep_style
,
sep
=
template
.
get
(
"sep"
,
"
\n
"
),
stop_str
=
template
[
"stop_str"
],
),
override
=
True
,
)
chat_template_name
=
template
[
"name"
]
else
:
chat_template_name
=
chat_template_arg
def
guess_chat_template_name_from_model_path
(
model_path
):
global
chat_template_name
chat_template_name
=
get_conv_template_by_model_path
(
model_path
)
if
chat_template_name
is
not
None
:
logger
.
info
(
f
"Infer the chat template name from the model path and obtain the result:
{
chat_template_name
}
."
)
def
_validate_prompt
(
prompt
:
str
):
"""Validate that the prompt is not empty or whitespace only."""
is_invalid
=
False
# Check for empty/whitespace string
if
isinstance
(
prompt
,
str
):
is_invalid
=
not
prompt
.
strip
()
# Check for various invalid list cases: [], [""], [" "], [[]]
elif
isinstance
(
prompt
,
list
):
is_invalid
=
not
prompt
or
(
len
(
prompt
)
==
1
and
(
(
isinstance
(
prompt
[
0
],
str
)
and
not
prompt
[
0
].
strip
())
or
(
isinstance
(
prompt
[
0
],
list
)
and
not
prompt
[
0
])
)
)
if
is_invalid
:
raise
HTTPException
(
status_code
=
400
,
detail
=
"Input cannot be empty or contain only whitespace."
,
)
return
prompt
async
def
v1_files_create
(
file
:
UploadFile
,
purpose
:
str
,
file_storage_path
:
str
=
None
):
try
:
global
storage_dir
if
file_storage_path
:
storage_dir
=
file_storage_path
# Read the file content
file_content
=
await
file
.
read
()
# Create an instance of RequestBody
request_body
=
FileRequest
(
file
=
file_content
,
purpose
=
purpose
)
# Save the file to the sglang_oai_storage directory
os
.
makedirs
(
storage_dir
,
exist_ok
=
True
)
file_id
=
f
"backend_input_file-
{
uuid
.
uuid4
()
}
"
filename
=
f
"
{
file_id
}
.jsonl"
file_path
=
os
.
path
.
join
(
storage_dir
,
filename
)
with
open
(
file_path
,
"wb"
)
as
f
:
f
.
write
(
request_body
.
file
)
# add info to global file map
file_id_request
[
file_id
]
=
FileMetadata
(
filename
=
file
.
filename
,
purpose
=
purpose
)
file_id_storage
[
file_id
]
=
file_path
# Return the response in the required format
response
=
FileResponse
(
id
=
file_id
,
bytes
=
len
(
request_body
.
file
),
created_at
=
int
(
time
.
time
()),
filename
=
file
.
filename
,
purpose
=
request_body
.
purpose
,
)
file_id_response
[
file_id
]
=
response
return
response
except
ValidationError
as
e
:
return
{
"error"
:
"Invalid input"
,
"details"
:
e
.
errors
()}
async
def
v1_delete_file
(
file_id
:
str
):
# Retrieve the file job from the in-memory storage
file_response
=
file_id_response
.
get
(
file_id
)
if
file_response
is
None
:
raise
HTTPException
(
status_code
=
404
,
detail
=
"File not found"
)
file_path
=
file_id_storage
.
get
(
file_id
)
if
file_path
is
None
:
raise
HTTPException
(
status_code
=
404
,
detail
=
"File not found"
)
os
.
remove
(
file_path
)
del
file_id_response
[
file_id
]
del
file_id_storage
[
file_id
]
return
FileDeleteResponse
(
id
=
file_id
,
deleted
=
True
)
async
def
v1_batches
(
tokenizer_manager
,
raw_request
:
Request
):
try
:
body
=
await
raw_request
.
json
()
batch_request
=
BatchRequest
(
**
body
)
batch_id
=
f
"batch_
{
uuid
.
uuid4
()
}
"
# Create an instance of BatchResponse
batch_response
=
BatchResponse
(
id
=
batch_id
,
endpoint
=
batch_request
.
endpoint
,
input_file_id
=
batch_request
.
input_file_id
,
completion_window
=
batch_request
.
completion_window
,
created_at
=
int
(
time
.
time
()),
metadata
=
batch_request
.
metadata
,
)
batch_storage
[
batch_id
]
=
batch_response
# Start processing the batch asynchronously
asyncio
.
create_task
(
process_batch
(
tokenizer_manager
,
batch_id
,
batch_request
))
# Return the initial batch_response
return
batch_response
except
ValidationError
as
e
:
return
{
"error"
:
"Invalid input"
,
"details"
:
e
.
errors
()}
except
Exception
as
e
:
return
{
"error"
:
str
(
e
)}
async
def
process_batch
(
tokenizer_manager
,
batch_id
:
str
,
batch_request
:
BatchRequest
):
try
:
# Update the batch status to "in_progress"
batch_storage
[
batch_id
].
status
=
"in_progress"
batch_storage
[
batch_id
].
in_progress_at
=
int
(
time
.
time
())
# Retrieve the input file content
input_file_request
=
file_id_request
.
get
(
batch_request
.
input_file_id
)
if
not
input_file_request
:
raise
ValueError
(
"Input file not found"
)
# Parse the JSONL file and process each request
input_file_path
=
file_id_storage
.
get
(
batch_request
.
input_file_id
)
with
open
(
input_file_path
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
lines
=
f
.
readlines
()
total_requests
=
len
(
lines
)
completed_requests
=
0
failed_requests
=
0
all_ret
=
[]
end_point
=
batch_storage
[
batch_id
].
endpoint
file_request_list
=
[]
all_requests
=
[]
request_ids
=
[]
for
line_id
,
line
in
enumerate
(
lines
):
request_data
=
json
.
loads
(
line
)
file_request_list
.
append
(
request_data
)
body
=
request_data
[
"body"
]
request_ids
.
append
(
f
"
{
batch_id
}
-req_
{
line_id
}
"
)
# Although streaming is supported for standalone completions, it is not supported in
# batch mode (multiple completions in single request).
if
body
.
get
(
"stream"
,
False
):
raise
ValueError
(
"Streaming requests are not supported in batch mode"
)
if
end_point
==
"/v1/chat/completions"
:
all_requests
.
append
(
ChatCompletionRequest
(
**
body
))
elif
end_point
==
"/v1/completions"
:
all_requests
.
append
(
CompletionRequest
(
**
body
))
if
end_point
==
"/v1/chat/completions"
:
adapted_request
,
request
=
v1_chat_generate_request
(
all_requests
,
tokenizer_manager
,
request_ids
=
request_ids
)
elif
end_point
==
"/v1/completions"
:
adapted_request
,
request
=
v1_generate_request
(
all_requests
,
request_ids
=
request_ids
)
try
:
created
=
int
(
time
.
time
())
ret
=
await
tokenizer_manager
.
generate_request
(
adapted_request
).
__anext__
()
if
not
isinstance
(
ret
,
list
):
ret
=
[
ret
]
if
end_point
==
"/v1/chat/completions"
:
responses
=
v1_chat_generate_response
(
request
,
ret
,
created
,
to_file
=
True
,
cache_report
=
tokenizer_manager
.
server_args
.
enable_cache_report
,
tool_call_parser
=
tokenizer_manager
.
server_args
.
tool_call_parser
,
)
else
:
responses
=
v1_generate_response
(
request
,
ret
,
tokenizer_manager
,
created
,
to_file
=
True
,
cache_report
=
tokenizer_manager
.
server_args
.
enable_cache_report
,
)
except
Exception
as
e
:
logger
.
error
(
f
"error:
{
get_exception_traceback
()
}
"
)
responses
=
[]
error_json
=
{
"id"
:
f
"batch_req_
{
uuid
.
uuid4
()
}
"
,
"custom_id"
:
request_data
.
get
(
"custom_id"
),
"response"
:
None
,
"error"
:
{
"message"
:
str
(
e
)},
}
all_ret
.
append
(
error_json
)
failed_requests
+=
len
(
file_request_list
)
for
idx
,
response
in
enumerate
(
responses
):
# the batch_req here can be changed to be named within a batch granularity
response_json
=
{
"id"
:
f
"batch_req_
{
uuid
.
uuid4
()
}
"
,
"custom_id"
:
file_request_list
[
idx
].
get
(
"custom_id"
),
"response"
:
response
,
"error"
:
None
,
}
all_ret
.
append
(
response_json
)
completed_requests
+=
1
# Write results to a new file
output_file_id
=
f
"backend_result_file-
{
uuid
.
uuid4
()
}
"
global
storage_dir
output_file_path
=
os
.
path
.
join
(
storage_dir
,
f
"
{
output_file_id
}
.jsonl"
)
with
open
(
output_file_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
for
ret
in
all_ret
:
f
.
write
(
json
.
dumps
(
ret
)
+
"
\n
"
)
# Update batch response with output file information
retrieve_batch
=
batch_storage
[
batch_id
]
retrieve_batch
.
output_file_id
=
output_file_id
file_id_storage
[
output_file_id
]
=
output_file_path
file_id_response
[
output_file_id
]
=
FileResponse
(
id
=
output_file_id
,
bytes
=
os
.
path
.
getsize
(
output_file_path
),
created_at
=
int
(
time
.
time
()),
filename
=
f
"
{
output_file_id
}
.jsonl"
,
purpose
=
"batch_result"
,
)
# Update batch status to "completed"
retrieve_batch
.
status
=
"completed"
retrieve_batch
.
completed_at
=
int
(
time
.
time
())
retrieve_batch
.
request_counts
=
{
"total"
:
total_requests
,
"completed"
:
completed_requests
,
"failed"
:
failed_requests
,
}
except
Exception
as
e
:
logger
.
error
(
f
"error:
{
e
}
"
)
# Update batch status to "failed"
retrieve_batch
=
batch_storage
[
batch_id
]
retrieve_batch
.
status
=
"failed"
retrieve_batch
.
failed_at
=
int
(
time
.
time
())
retrieve_batch
.
errors
=
{
"message"
:
str
(
e
)}
async
def
v1_retrieve_batch
(
batch_id
:
str
):
# Retrieve the batch job from the in-memory storage
batch_response
=
batch_storage
.
get
(
batch_id
)
if
batch_response
is
None
:
raise
HTTPException
(
status_code
=
404
,
detail
=
"Batch not found"
)
return
batch_response
async
def
v1_cancel_batch
(
tokenizer_manager
,
batch_id
:
str
):
# Retrieve the batch job from the in-memory storage
batch_response
=
batch_storage
.
get
(
batch_id
)
if
batch_response
is
None
:
raise
HTTPException
(
status_code
=
404
,
detail
=
"Batch not found"
)
# Only do cancal when status is "validating" or "in_progress"
if
batch_response
.
status
in
[
"validating"
,
"in_progress"
]:
# Start cancelling the batch asynchronously
asyncio
.
create_task
(
cancel_batch
(
tokenizer_manager
=
tokenizer_manager
,
batch_id
=
batch_id
,
input_file_id
=
batch_response
.
input_file_id
,
)
)
# Update batch status to "cancelling"
batch_response
.
status
=
"cancelling"
return
batch_response
else
:
raise
HTTPException
(
status_code
=
500
,
detail
=
f
"Current status is
{
batch_response
.
status
}
, no need to cancel"
,
)
async
def
cancel_batch
(
tokenizer_manager
,
batch_id
:
str
,
input_file_id
:
str
):
try
:
# Update the batch status to "cancelling"
batch_storage
[
batch_id
].
status
=
"cancelling"
# Retrieve the input file content
input_file_request
=
file_id_request
.
get
(
input_file_id
)
if
not
input_file_request
:
raise
ValueError
(
"Input file not found"
)
# Parse the JSONL file and process each request
input_file_path
=
file_id_storage
.
get
(
input_file_id
)
with
open
(
input_file_path
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
lines
=
f
.
readlines
()
# Cancel requests by request_ids
for
line_id
in
range
(
len
(
lines
)):
rid
=
f
"
{
batch_id
}
-req_
{
line_id
}
"
tokenizer_manager
.
abort_request
(
rid
=
rid
)
retrieve_batch
=
batch_storage
[
batch_id
]
retrieve_batch
.
status
=
"cancelled"
except
Exception
as
e
:
logger
.
error
(
"error in SGLang:"
,
e
)
# Update batch status to "failed"
retrieve_batch
=
batch_storage
[
batch_id
]
retrieve_batch
.
status
=
"failed"
retrieve_batch
.
failed_at
=
int
(
time
.
time
())
retrieve_batch
.
errors
=
{
"message"
:
str
(
e
)}
async
def
v1_retrieve_file
(
file_id
:
str
):
# Retrieve the batch job from the in-memory storage
file_response
=
file_id_response
.
get
(
file_id
)
if
file_response
is
None
:
raise
HTTPException
(
status_code
=
404
,
detail
=
"File not found"
)
return
file_response
async
def
v1_retrieve_file_content
(
file_id
:
str
):
file_pth
=
file_id_storage
.
get
(
file_id
)
if
not
file_pth
or
not
os
.
path
.
exists
(
file_pth
):
raise
HTTPException
(
status_code
=
404
,
detail
=
"File not found"
)
def
iter_file
():
with
open
(
file_pth
,
mode
=
"rb"
)
as
file_like
:
yield
from
file_like
return
StreamingResponse
(
iter_file
(),
media_type
=
"application/octet-stream"
)
def
v1_generate_request
(
all_requests
:
List
[
CompletionRequest
],
request_ids
:
List
[
str
]
=
None
):
if
len
(
all_requests
)
>
1
:
first_prompt_type
=
type
(
all_requests
[
0
].
prompt
)
for
request
in
all_requests
:
assert
(
type
(
request
.
prompt
)
is
first_prompt_type
),
"All prompts must be of the same type in file input settings"
if
request
.
n
>
1
:
raise
ValueError
(
"Parallel sampling is not supported for completions from files"
)
prompts
=
[]
sampling_params_list
=
[]
return_logprobs
=
[]
logprob_start_lens
=
[]
top_logprobs_nums
=
[]
lora_paths
=
[]
return_hidden_states
=
[]
for
request
in
all_requests
:
# NOTE: with openai API, the prompt's logprobs are always not computed
if
request
.
echo
and
request
.
logprobs
:
logger
.
warning
(
"Echo is not compatible with logprobs. "
"To compute logprobs of input prompt, please use the native /generate API."
)
prompt
=
request
.
prompt
if
is_completion_template_defined
():
prompt
=
generate_completion_prompt_from_request
(
request
)
prompts
.
append
(
prompt
)
lora_paths
.
append
(
request
.
lora_path
)
if
request
.
echo
and
request
.
logprobs
:
current_logprob_start_len
=
0
else
:
current_logprob_start_len
=
-
1
sampling_params_list
.
append
(
{
"temperature"
:
request
.
temperature
,
"max_new_tokens"
:
request
.
max_tokens
,
"min_new_tokens"
:
request
.
min_tokens
,
"stop"
:
request
.
stop
,
"stop_token_ids"
:
request
.
stop_token_ids
,
"top_p"
:
request
.
top_p
,
"top_k"
:
request
.
top_k
,
"min_p"
:
request
.
min_p
,
"presence_penalty"
:
request
.
presence_penalty
,
"frequency_penalty"
:
request
.
frequency_penalty
,
"repetition_penalty"
:
request
.
repetition_penalty
,
"regex"
:
request
.
regex
,
"json_schema"
:
request
.
json_schema
,
"ebnf"
:
request
.
ebnf
,
"n"
:
request
.
n
,
"no_stop_trim"
:
request
.
no_stop_trim
,
"ignore_eos"
:
request
.
ignore_eos
,
"skip_special_tokens"
:
request
.
skip_special_tokens
,
"logit_bias"
:
request
.
logit_bias
,
}
)
return_logprobs
.
append
(
request
.
logprobs
is
not
None
)
logprob_start_lens
.
append
(
current_logprob_start_len
)
top_logprobs_nums
.
append
(
request
.
logprobs
if
request
.
logprobs
is
not
None
else
0
)
return_hidden_states
.
append
(
request
.
return_hidden_states
)
if
len
(
all_requests
)
==
1
:
if
isinstance
(
prompts
[
0
],
str
)
or
isinstance
(
prompts
[
0
][
0
],
str
):
prompt_kwargs
=
{
"text"
:
prompts
[
0
]}
else
:
prompt_kwargs
=
{
"input_ids"
:
prompts
[
0
]}
sampling_params_list
=
sampling_params_list
[
0
]
return_logprobs
=
return_logprobs
[
0
]
logprob_start_lens
=
logprob_start_lens
[
0
]
top_logprobs_nums
=
top_logprobs_nums
[
0
]
lora_paths
=
lora_paths
[
0
]
return_hidden_states
=
return_hidden_states
[
0
]
else
:
if
isinstance
(
prompts
[
0
],
str
)
or
isinstance
(
prompts
[
0
][
0
],
str
):
prompt_kwargs
=
{
"text"
:
prompts
}
else
:
prompt_kwargs
=
{
"input_ids"
:
prompts
}
adapted_request
=
GenerateReqInput
(
**
prompt_kwargs
,
sampling_params
=
sampling_params_list
,
return_logprob
=
return_logprobs
,
top_logprobs_num
=
top_logprobs_nums
,
logprob_start_len
=
logprob_start_lens
,
return_text_in_logprobs
=
True
,
stream
=
all_requests
[
0
].
stream
,
rid
=
request_ids
,
lora_path
=
lora_paths
,
return_hidden_states
=
return_hidden_states
,
bootstrap_host
=
all_requests
[
0
].
bootstrap_host
,
bootstrap_port
=
all_requests
[
0
].
bootstrap_port
,
bootstrap_room
=
all_requests
[
0
].
bootstrap_room
,
)
return
adapted_request
,
all_requests
if
len
(
all_requests
)
>
1
else
all_requests
[
0
]
def
v1_generate_response
(
request
,
ret
,
tokenizer_manager
,
created
,
to_file
=
False
,
cache_report
=
False
):
choices
=
[]
echo
=
False
if
(
not
isinstance
(
request
,
list
))
and
request
.
echo
:
# TODO: handle the case prompt is token ids
if
isinstance
(
request
.
prompt
,
list
)
and
isinstance
(
request
.
prompt
[
0
],
str
):
# for the case of multiple str prompts
prompts
=
request
.
prompt
elif
isinstance
(
request
.
prompt
,
list
)
and
isinstance
(
request
.
prompt
[
0
],
list
):
# for the case of multiple token ids prompts
prompts
=
[
tokenizer_manager
.
tokenizer
.
decode
(
prompt
,
skip_special_tokens
=
True
)
for
prompt
in
request
.
prompt
]
elif
isinstance
(
request
.
prompt
,
list
)
and
isinstance
(
request
.
prompt
[
0
],
int
):
# for the case of single token ids prompt
prompts
=
[
tokenizer_manager
.
tokenizer
.
decode
(
request
.
prompt
,
skip_special_tokens
=
True
)
]
else
:
# for the case of single str prompt
prompts
=
[
request
.
prompt
]
echo
=
True
for
idx
,
ret_item
in
enumerate
(
ret
):
text
=
ret_item
[
"text"
]
if
isinstance
(
request
,
list
)
and
request
[
idx
].
echo
:
echo
=
True
text
=
request
[
idx
].
prompt
+
text
if
echo
and
not
isinstance
(
request
,
list
):
prompt_index
=
idx
//
request
.
n
text
=
prompts
[
prompt_index
]
+
text
logprobs
=
False
if
isinstance
(
request
,
list
)
and
request
[
idx
].
logprobs
is
not
None
:
logprobs
=
True
elif
(
not
isinstance
(
request
,
list
))
and
request
.
logprobs
is
not
None
:
logprobs
=
True
if
logprobs
:
if
echo
:
input_token_logprobs
=
ret_item
[
"meta_info"
][
"input_token_logprobs"
]
input_top_logprobs
=
ret_item
[
"meta_info"
][
"input_top_logprobs"
]
else
:
input_token_logprobs
=
None
input_top_logprobs
=
None
logprobs
=
to_openai_style_logprobs
(
input_token_logprobs
=
input_token_logprobs
,
input_top_logprobs
=
input_top_logprobs
,
output_token_logprobs
=
ret_item
[
"meta_info"
][
"output_token_logprobs"
],
output_top_logprobs
=
ret_item
[
"meta_info"
][
"output_top_logprobs"
],
)
else
:
logprobs
=
None
hidden_states
=
None
if
isinstance
(
request
,
list
)
and
request
[
idx
].
return_hidden_states
:
hidden_states
=
ret_item
[
"meta_info"
].
get
(
"hidden_states"
,
None
)
elif
(
not
isinstance
(
request
,
list
))
and
request
.
return_hidden_states
:
hidden_states
=
ret_item
[
"meta_info"
].
get
(
"hidden_states"
,
None
)
if
hidden_states
is
not
None
:
hidden_states
=
(
hidden_states
[
-
1
]
if
hidden_states
and
len
(
hidden_states
)
>
1
else
[]
)
finish_reason
=
ret_item
[
"meta_info"
][
"finish_reason"
]
if
to_file
:
# to make the choice data json serializable
choice_data
=
{
"index"
:
0
,
"text"
:
text
,
"logprobs"
:
logprobs
,
"finish_reason"
:
finish_reason
[
"type"
]
if
finish_reason
else
None
,
"matched_stop"
:
(
finish_reason
[
"matched"
]
if
finish_reason
and
"matched"
in
finish_reason
else
None
),
}
if
hidden_states
is
not
None
:
choice_data
[
"hidden_states"
]
=
hidden_states
else
:
choice_data
=
CompletionResponseChoice
(
index
=
idx
,
text
=
text
,
logprobs
=
logprobs
,
finish_reason
=
finish_reason
[
"type"
]
if
finish_reason
else
None
,
matched_stop
=
(
finish_reason
[
"matched"
]
if
finish_reason
and
"matched"
in
finish_reason
else
None
),
hidden_states
=
hidden_states
,
)
choices
.
append
(
choice_data
)
if
to_file
:
responses
=
[]
for
i
,
choice
in
enumerate
(
choices
):
response
=
{
"status_code"
:
200
,
"request_id"
:
ret
[
i
][
"meta_info"
][
"id"
],
"body"
:
{
# remain the same but if needed we can change that
"id"
:
ret
[
i
][
"meta_info"
][
"id"
],
"object"
:
"text_completion"
,
"created"
:
created
,
"model"
:
request
[
i
].
model
,
"choices"
:
choice
,
"usage"
:
{
"prompt_tokens"
:
ret
[
i
][
"meta_info"
][
"prompt_tokens"
],
"completion_tokens"
:
ret
[
i
][
"meta_info"
][
"completion_tokens"
],
"total_tokens"
:
ret
[
i
][
"meta_info"
][
"prompt_tokens"
]
+
ret
[
i
][
"meta_info"
][
"completion_tokens"
],
},
"system_fingerprint"
:
None
,
},
}
responses
.
append
(
response
)
return
responses
else
:
prompt_tokens
=
sum
(
ret
[
i
][
"meta_info"
][
"prompt_tokens"
]
for
i
in
range
(
0
,
len
(
ret
),
request
.
n
)
)
completion_tokens
=
sum
(
item
[
"meta_info"
][
"completion_tokens"
]
for
item
in
ret
)
cached_tokens
=
sum
(
item
[
"meta_info"
].
get
(
"cached_tokens"
,
0
)
for
item
in
ret
)
response
=
CompletionResponse
(
id
=
ret
[
0
][
"meta_info"
][
"id"
],
model
=
request
.
model
,
created
=
created
,
choices
=
choices
,
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
completion_tokens
,
prompt_tokens_details
=
(
{
"cached_tokens"
:
cached_tokens
}
if
cache_report
else
None
),
),
)
return
response
async
def
v1_completions
(
tokenizer_manager
,
raw_request
:
Request
):
try
:
request_json
=
await
raw_request
.
json
()
except
Exception
as
e
:
return
create_error_response
(
"Invalid request body, error: "
,
str
(
e
))
all_requests
=
[
CompletionRequest
(
**
request_json
)]
created
=
int
(
time
.
time
())
adapted_request
,
request
=
v1_generate_request
(
all_requests
)
if
adapted_request
.
stream
:
async
def
generate_stream_resp
():
stream_buffers
=
{}
n_prev_tokens
=
{}
prompt_tokens
=
{}
completion_tokens
=
{}
cached_tokens
=
{}
hidden_states
=
{}
try
:
async
for
content
in
tokenizer_manager
.
generate_request
(
adapted_request
,
raw_request
):
index
=
content
.
get
(
"index"
,
0
)
stream_buffer
=
stream_buffers
.
get
(
index
,
""
)
n_prev_token
=
n_prev_tokens
.
get
(
index
,
0
)
text
=
content
[
"text"
]
prompt_tokens
[
index
]
=
content
[
"meta_info"
][
"prompt_tokens"
]
completion_tokens
[
index
]
=
content
[
"meta_info"
][
"completion_tokens"
]
cached_tokens
[
index
]
=
content
[
"meta_info"
].
get
(
"cached_tokens"
,
0
)
hidden_states
[
index
]
=
content
[
"meta_info"
].
get
(
"hidden_states"
,
None
)
or
hidden_states
.
get
(
index
)
if
not
stream_buffer
:
# The first chunk
if
request
.
echo
:
if
isinstance
(
request
.
prompt
,
str
):
# for the case of single str prompts
prompts
=
request
.
prompt
elif
isinstance
(
request
.
prompt
,
list
):
if
isinstance
(
request
.
prompt
[
0
],
str
):
# for the case of multiple str prompts
prompts
=
request
.
prompt
[
index
//
request
.
n
]
elif
isinstance
(
request
.
prompt
[
0
],
int
):
# for the case of single token ids prompt
prompts
=
tokenizer_manager
.
tokenizer
.
decode
(
request
.
prompt
,
skip_special_tokens
=
True
)
elif
isinstance
(
request
.
prompt
[
0
],
list
)
and
isinstance
(
request
.
prompt
[
0
][
0
],
int
):
# for the case of multiple token ids prompts
prompts
=
tokenizer_manager
.
tokenizer
.
decode
(
request
.
prompt
[
index
//
request
.
n
],
skip_special_tokens
=
True
,
)
# Prepend prompt in response text.
text
=
prompts
+
text
if
request
.
logprobs
is
not
None
:
# The first chunk and echo is enabled.
if
not
stream_buffer
and
request
.
echo
:
input_token_logprobs
=
content
[
"meta_info"
][
"input_token_logprobs"
]
input_top_logprobs
=
content
[
"meta_info"
][
"input_top_logprobs"
]
else
:
input_token_logprobs
=
None
input_top_logprobs
=
None
logprobs
=
to_openai_style_logprobs
(
input_token_logprobs
=
input_token_logprobs
,
input_top_logprobs
=
input_top_logprobs
,
output_token_logprobs
=
content
[
"meta_info"
][
"output_token_logprobs"
][
n_prev_token
:],
output_top_logprobs
=
content
[
"meta_info"
][
"output_top_logprobs"
][
n_prev_token
:],
)
n_prev_token
=
len
(
content
[
"meta_info"
][
"output_token_logprobs"
]
)
else
:
logprobs
=
None
delta
=
text
[
len
(
stream_buffer
)
:]
stream_buffer
=
stream_buffer
+
delta
finish_reason
=
content
[
"meta_info"
][
"finish_reason"
]
choice_data
=
CompletionResponseStreamChoice
(
index
=
index
,
text
=
delta
,
logprobs
=
logprobs
,
finish_reason
=
finish_reason
[
"type"
]
if
finish_reason
else
None
,
matched_stop
=
(
finish_reason
[
"matched"
]
if
finish_reason
and
"matched"
in
finish_reason
else
None
),
)
chunk
=
CompletionStreamResponse
(
id
=
content
[
"meta_info"
][
"id"
],
created
=
created
,
object
=
"text_completion"
,
choices
=
[
choice_data
],
model
=
request
.
model
,
)
stream_buffers
[
index
]
=
stream_buffer
n_prev_tokens
[
index
]
=
n_prev_token
yield
f
"data:
{
chunk
.
model_dump_json
()
}
\n\n
"
if
request
.
return_hidden_states
and
hidden_states
:
for
index
,
choice_hidden_states
in
hidden_states
.
items
():
last_token_hidden_states
=
(
choice_hidden_states
[
-
1
]
if
choice_hidden_states
and
len
(
choice_hidden_states
)
>
1
else
[]
)
hidden_states_chunk
=
CompletionStreamResponse
(
id
=
content
[
"meta_info"
][
"id"
],
created
=
created
,
choices
=
[
CompletionResponseStreamChoice
(
text
=
""
,
index
=
index
,
hidden_states
=
last_token_hidden_states
,
finish_reason
=
None
,
)
],
model
=
request
.
model
,
)
yield
f
"data:
{
hidden_states_chunk
.
model_dump_json
()
}
\n\n
"
if
request
.
stream_options
and
request
.
stream_options
.
include_usage
:
total_prompt_tokens
=
sum
(
tokens
for
i
,
tokens
in
prompt_tokens
.
items
()
if
i
%
request
.
n
==
0
)
total_completion_tokens
=
sum
(
tokens
for
tokens
in
completion_tokens
.
values
()
)
cache_report
=
tokenizer_manager
.
server_args
.
enable_cache_report
if
cache_report
:
cached_tokens_sum
=
sum
(
tokens
for
tokens
in
cached_tokens
.
values
()
)
prompt_tokens_details
=
{
"cached_tokens"
:
cached_tokens_sum
}
else
:
prompt_tokens_details
=
None
usage
=
UsageInfo
(
prompt_tokens
=
total_prompt_tokens
,
completion_tokens
=
total_completion_tokens
,
total_tokens
=
total_prompt_tokens
+
total_completion_tokens
,
prompt_tokens_details
=
prompt_tokens_details
,
)
final_usage_chunk
=
CompletionStreamResponse
(
id
=
content
[
"meta_info"
][
"id"
],
created
=
created
,
choices
=
[],
model
=
request
.
model
,
usage
=
usage
,
)
final_usage_data
=
final_usage_chunk
.
model_dump_json
(
exclude_none
=
True
)
yield
f
"data:
{
final_usage_data
}
\n\n
"
except
ValueError
as
e
:
error
=
create_streaming_error_response
(
str
(
e
))
yield
f
"data:
{
error
}
\n\n
"
yield
"data: [DONE]
\n\n
"
return
StreamingResponse
(
generate_stream_resp
(),
media_type
=
"text/event-stream"
,
background
=
tokenizer_manager
.
create_abort_task
(
adapted_request
),
)
# Non-streaming response.
try
:
ret
=
await
tokenizer_manager
.
generate_request
(
adapted_request
,
raw_request
).
__anext__
()
except
ValueError
as
e
:
return
create_error_response
(
str
(
e
))
if
not
isinstance
(
ret
,
list
):
ret
=
[
ret
]
response
=
v1_generate_response
(
request
,
ret
,
tokenizer_manager
,
created
,
cache_report
=
tokenizer_manager
.
server_args
.
enable_cache_report
,
)
return
response
def
_get_enable_thinking_from_request
(
request_obj
):
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
Args:
request_obj: The request object (or an item from a list of requests).
Returns:
The boolean value of 'enable_thinking' if found and not True, otherwise True.
"""
if
(
hasattr
(
request_obj
,
"chat_template_kwargs"
)
and
request_obj
.
chat_template_kwargs
and
request_obj
.
chat_template_kwargs
.
get
(
"enable_thinking"
)
is
not
None
):
return
request_obj
.
chat_template_kwargs
.
get
(
"enable_thinking"
)
return
True
def
v1_chat_generate_request
(
all_requests
:
List
[
ChatCompletionRequest
],
tokenizer_manager
,
request_ids
:
List
[
str
]
=
None
,
):
input_ids
=
[]
prompts
=
[]
sampling_params_list
=
[]
image_data_list
=
[]
audio_data_list
=
[]
return_logprobs
=
[]
logprob_start_lens
=
[]
top_logprobs_nums
=
[]
modalities_list
=
[]
lora_paths
=
[]
return_hidden_states
=
[]
# NOTE: with openai API, the prompt's logprobs are always not computed
is_multimodal
=
tokenizer_manager
.
model_config
.
is_multimodal
for
request
in
all_requests
:
# Prep the data needed for the underlying GenerateReqInput:
# - prompt: The full prompt string.
# - stop: Custom stop tokens.
# - image_data: None or a list of image strings (URLs or base64 strings).
# - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput.
tool_call_constraint
=
None
prompt
=
""
prompt_ids
=
[]
if
not
isinstance
(
request
.
messages
,
str
):
# Apply chat template and its stop strings.
tools
=
None
if
request
.
tools
and
request
.
tool_choice
!=
"none"
:
request
.
skip_special_tokens
=
False
if
not
isinstance
(
request
.
tool_choice
,
str
):
tools
=
[
item
.
function
.
model_dump
()
for
item
in
request
.
tools
if
item
.
function
.
name
==
request
.
tool_choice
.
function
.
name
]
else
:
tools
=
[
item
.
function
.
model_dump
()
for
item
in
request
.
tools
]
tool_call_parser
=
tokenizer_manager
.
server_args
.
tool_call_parser
parser
=
FunctionCallParser
(
request
.
tools
,
tool_call_parser
)
tool_call_constraint
=
parser
.
get_structure_constraint
(
request
.
tool_choice
)
if
chat_template_name
is
None
:
openai_compatible_messages
=
[]
image_data
=
[]
audio_data
=
[]
modalities
=
[]
# Detect template content format by analyzing the jinja template (cached globally)
global
_cached_chat_template
,
_cached_template_format
current_template
=
tokenizer_manager
.
tokenizer
.
chat_template
if
current_template
!=
_cached_chat_template
:
# Template changed or first time - analyze it
_cached_chat_template
=
current_template
_cached_template_format
=
detect_template_content_format
(
current_template
)
logger
.
info
(
f
"Detected chat template content format:
{
_cached_template_format
}
"
)
template_content_format
=
_cached_template_format
for
message
in
request
.
messages
:
if
message
.
content
is
None
:
message
.
content
=
""
msg_dict
=
message
.
model_dump
()
# Process content based on detected template format
processed_msg
=
process_content_for_template_format
(
msg_dict
,
template_content_format
,
image_data
,
audio_data
,
modalities
,
)
openai_compatible_messages
.
append
(
processed_msg
)
# Handle assistant prefix for continue_final_message
if
(
openai_compatible_messages
and
openai_compatible_messages
[
-
1
][
"role"
]
==
"assistant"
):
if
request
.
continue_final_message
:
# Remove the final assistant message so its content can be continued.
assistant_prefix
=
openai_compatible_messages
[
-
1
][
"content"
]
openai_compatible_messages
=
openai_compatible_messages
[:
-
1
]
else
:
assistant_prefix
=
None
else
:
assistant_prefix
=
None
try
:
prompt_ids
=
tokenizer_manager
.
tokenizer
.
apply_chat_template
(
openai_compatible_messages
,
tokenize
=
True
,
add_generation_prompt
=
True
,
tools
=
tools
,
**
(
request
.
chat_template_kwargs
if
request
.
chat_template_kwargs
else
{}
),
)
except
:
# This except branch will be triggered when the chosen model
# has a different tools input format that is not compatible
# with openAI's apply_chat_template tool_call format, like Mistral.
tools
=
[
t
if
"function"
in
t
else
{
"function"
:
t
}
for
t
in
tools
]
prompt_ids
=
tokenizer_manager
.
tokenizer
.
apply_chat_template
(
openai_compatible_messages
,
tokenize
=
True
,
add_generation_prompt
=
True
,
tools
=
tools
,
**
(
request
.
chat_template_kwargs
if
request
.
chat_template_kwargs
else
{}
),
)
if
assistant_prefix
:
encoded
=
tokenizer_manager
.
tokenizer
.
encode
(
assistant_prefix
)
if
(
encoded
and
encoded
[
0
]
==
tokenizer_manager
.
tokenizer
.
bos_token_id
):
encoded
=
encoded
[
1
:]
prompt_ids
+=
encoded
if
is_multimodal
:
prompt
=
tokenizer_manager
.
tokenizer
.
decode
(
prompt_ids
)
stop
=
request
.
stop
image_data
=
image_data
if
image_data
else
None
audio_data
=
audio_data
if
audio_data
else
None
modalities
=
modalities
if
modalities
else
[]
else
:
conv
=
generate_chat_conv
(
request
,
chat_template_name
)
# If we should continue the final assistant message, adjust the conversation.
if
(
request
.
continue_final_message
and
request
.
messages
and
request
.
messages
[
-
1
].
role
==
"assistant"
):
# Remove the auto-added blank assistant turn, if present.
if
conv
.
messages
and
conv
.
messages
[
-
1
][
1
]
is
None
:
conv
.
messages
.
pop
()
# Rebuild the prompt from the conversation.
prompt
=
conv
.
get_prompt
()
# Strip any trailing stop tokens or separators that indicate end-of-assistant.
if
isinstance
(
conv
.
stop_str
,
list
):
for
stop_token
in
conv
.
stop_str
:
if
prompt
.
endswith
(
stop_token
):
prompt
=
prompt
[:
-
len
(
stop_token
)]
elif
isinstance
(
conv
.
stop_str
,
str
)
and
prompt
.
endswith
(
conv
.
stop_str
):
prompt
=
prompt
[:
-
len
(
conv
.
stop_str
)]
if
conv
.
sep
and
prompt
.
endswith
(
conv
.
sep
):
prompt
=
prompt
[:
-
len
(
conv
.
sep
)]
if
getattr
(
conv
,
"sep2"
,
None
)
and
prompt
.
endswith
(
conv
.
sep2
):
prompt
=
prompt
[:
-
len
(
conv
.
sep2
)]
else
:
prompt
=
conv
.
get_prompt
()
image_data
=
conv
.
image_data
audio_data
=
conv
.
audio_data
modalities
=
conv
.
modalities
stop
=
conv
.
stop_str
or
[]
if
not
request
.
ignore_eos
else
[]
if
request
.
stop
:
if
isinstance
(
request
.
stop
,
str
):
stop
.
append
(
request
.
stop
)
else
:
stop
.
extend
(
request
.
stop
)
if
not
is_multimodal
:
prompt_ids
=
tokenizer_manager
.
tokenizer
.
encode
(
prompt
)
else
:
# Use the raw prompt and stop strings if the messages is already a string.
prompt_ids
=
request
.
messages
stop
=
request
.
stop
image_data
=
None
audio_data
=
None
modalities
=
[]
prompt
=
request
.
messages
input_ids
.
append
(
prompt_ids
)
return_logprobs
.
append
(
request
.
logprobs
)
logprob_start_lens
.
append
(
-
1
)
top_logprobs_nums
.
append
(
request
.
top_logprobs
or
0
)
lora_paths
.
append
(
request
.
lora_path
)
prompts
.
append
(
prompt
)
sampling_params
=
{
"temperature"
:
request
.
temperature
,
"max_new_tokens"
:
request
.
max_tokens
or
request
.
max_completion_tokens
,
"min_new_tokens"
:
request
.
min_tokens
,
"stop"
:
stop
,
"stop_token_ids"
:
request
.
stop_token_ids
,
"top_p"
:
request
.
top_p
,
"top_k"
:
request
.
top_k
,
"min_p"
:
request
.
min_p
,
"presence_penalty"
:
request
.
presence_penalty
,
"frequency_penalty"
:
request
.
frequency_penalty
,
"repetition_penalty"
:
request
.
repetition_penalty
,
"regex"
:
request
.
regex
,
"ebnf"
:
request
.
ebnf
,
"n"
:
request
.
n
,
"no_stop_trim"
:
request
.
no_stop_trim
,
"ignore_eos"
:
request
.
ignore_eos
,
"skip_special_tokens"
:
request
.
skip_special_tokens
,
"logit_bias"
:
request
.
logit_bias
,
}
if
request
.
response_format
and
request
.
response_format
.
type
==
"json_schema"
:
sampling_params
[
"json_schema"
]
=
convert_json_schema_to_str
(
request
.
response_format
.
json_schema
.
schema_
)
elif
request
.
response_format
and
request
.
response_format
.
type
==
"json_object"
:
sampling_params
[
"json_schema"
]
=
'{"type": "object"}'
elif
(
request
.
response_format
and
request
.
response_format
.
type
==
"structural_tag"
):
sampling_params
[
"structural_tag"
]
=
convert_json_schema_to_str
(
request
.
response_format
.
model_dump
(
by_alias
=
True
)
)
# Check if there are already existing output constraints
has_existing_constraints
=
(
sampling_params
.
get
(
"regex"
)
or
sampling_params
.
get
(
"ebnf"
)
or
sampling_params
.
get
(
"structural_tag"
)
or
sampling_params
.
get
(
"json_schema"
)
)
if
tool_call_constraint
and
has_existing_constraints
:
logger
.
warning
(
"Constrained decoding is not compatible with tool calls."
)
elif
tool_call_constraint
:
constraint_type
,
constraint_value
=
tool_call_constraint
if
constraint_type
==
"structural_tag"
:
sampling_params
[
constraint_type
]
=
convert_json_schema_to_str
(
constraint_value
.
model_dump
(
by_alias
=
True
)
)
else
:
sampling_params
[
constraint_type
]
=
constraint_value
sampling_params_list
.
append
(
sampling_params
)
image_data_list
.
append
(
image_data
)
audio_data_list
.
append
(
audio_data
)
modalities_list
.
append
(
modalities
)
return_hidden_states
.
append
(
request
.
return_hidden_states
)
if
len
(
all_requests
)
==
1
:
if
is_multimodal
:
# processor will need text input
prompt_kwargs
=
{
"text"
:
prompts
[
0
]}
else
:
if
isinstance
(
input_ids
[
0
],
str
):
prompt_kwargs
=
{
"text"
:
input_ids
[
0
]}
else
:
prompt_kwargs
=
{
"input_ids"
:
input_ids
[
0
]}
sampling_params_list
=
sampling_params_list
[
0
]
image_data_list
=
image_data_list
[
0
]
audio_data_list
=
audio_data_list
[
0
]
return_logprobs
=
return_logprobs
[
0
]
logprob_start_lens
=
logprob_start_lens
[
0
]
top_logprobs_nums
=
top_logprobs_nums
[
0
]
modalities_list
=
modalities_list
[
0
]
lora_paths
=
lora_paths
[
0
]
request_ids
=
request_ids
[
0
]
return_hidden_states
=
return_hidden_states
[
0
]
else
:
if
tokenizer_manager
.
model_config
.
is_multimodal
:
# processor will need text input
prompt_kwargs
=
{
"text"
:
prompts
}
else
:
if
isinstance
(
input_ids
[
0
],
str
):
prompt_kwargs
=
{
"text"
:
input_ids
}
else
:
prompt_kwargs
=
{
"input_ids"
:
input_ids
}
adapted_request
=
GenerateReqInput
(
**
prompt_kwargs
,
image_data
=
image_data_list
,
audio_data
=
audio_data_list
,
sampling_params
=
sampling_params_list
,
return_logprob
=
return_logprobs
,
logprob_start_len
=
logprob_start_lens
,
top_logprobs_num
=
top_logprobs_nums
,
stream
=
all_requests
[
0
].
stream
,
return_text_in_logprobs
=
True
,
rid
=
request_ids
,
modalities
=
modalities_list
,
lora_path
=
lora_paths
,
bootstrap_host
=
all_requests
[
0
].
bootstrap_host
,
bootstrap_port
=
all_requests
[
0
].
bootstrap_port
,
bootstrap_room
=
all_requests
[
0
].
bootstrap_room
,
return_hidden_states
=
return_hidden_states
,
)
return
adapted_request
,
all_requests
if
len
(
all_requests
)
>
1
else
all_requests
[
0
]
def
v1_chat_generate_response
(
request
,
ret
,
created
,
to_file
=
False
,
cache_report
=
False
,
tool_call_parser
=
None
,
reasoning_parser
=
None
,
):
choices
=
[]
for
idx
,
ret_item
in
enumerate
(
ret
):
logprobs
=
False
if
isinstance
(
request
,
list
)
and
request
[
idx
].
logprobs
:
logprobs
=
True
elif
(
not
isinstance
(
request
,
list
))
and
request
.
logprobs
:
logprobs
=
True
if
logprobs
:
logprobs
=
to_openai_style_logprobs
(
output_token_logprobs
=
ret_item
[
"meta_info"
][
"output_token_logprobs"
],
output_top_logprobs
=
ret_item
[
"meta_info"
].
get
(
"output_top_logprobs"
,
None
),
)
token_logprobs
=
[]
for
token_idx
,
(
token
,
logprob
)
in
enumerate
(
zip
(
logprobs
.
tokens
,
logprobs
.
token_logprobs
)
):
token_bytes
=
list
(
token
.
encode
(
"utf-8"
))
top_logprobs
=
[]
if
logprobs
.
top_logprobs
:
for
top_token
,
top_logprob
in
logprobs
.
top_logprobs
[
token_idx
].
items
():
top_token_bytes
=
list
(
top_token
.
encode
(
"utf-8"
))
top_logprobs
.
append
(
TopLogprob
(
token
=
top_token
,
bytes
=
top_token_bytes
,
logprob
=
top_logprob
,
)
)
token_logprobs
.
append
(
ChatCompletionTokenLogprob
(
token
=
token
,
bytes
=
token_bytes
,
logprob
=
logprob
,
top_logprobs
=
top_logprobs
,
)
)
choice_logprobs
=
ChoiceLogprobs
(
content
=
token_logprobs
)
else
:
choice_logprobs
=
None
if
isinstance
(
request
,
list
)
and
request
[
idx
].
return_hidden_states
:
include_hidden_states
=
True
elif
not
isinstance
(
request
,
list
)
and
request
.
return_hidden_states
:
include_hidden_states
=
True
else
:
include_hidden_states
=
False
if
include_hidden_states
and
ret_item
[
"meta_info"
].
get
(
"hidden_states"
,
None
):
hidden_states
=
ret_item
[
"meta_info"
][
"hidden_states"
]
hidden_states
=
(
hidden_states
[
-
1
]
if
hidden_states
and
len
(
hidden_states
)
>
1
else
[]
)
else
:
hidden_states
=
None
finish_reason
=
ret_item
[
"meta_info"
][
"finish_reason"
]
tool_calls
=
None
text
=
ret_item
[
"text"
]
if
isinstance
(
request
,
list
):
tool_choice
=
request
[
idx
].
tool_choice
tools
=
request
[
idx
].
tools
separate_reasoning
=
request
[
idx
].
separate_reasoning
enable_thinking
=
_get_enable_thinking_from_request
(
request
[
idx
])
else
:
tool_choice
=
request
.
tool_choice
tools
=
request
.
tools
separate_reasoning
=
request
.
separate_reasoning
enable_thinking
=
_get_enable_thinking_from_request
(
request
)
reasoning_text
=
None
if
reasoning_parser
and
separate_reasoning
and
enable_thinking
:
try
:
parser
=
ReasoningParser
(
model_type
=
reasoning_parser
,
stream_reasoning
=
False
)
reasoning_text
,
text
=
parser
.
parse_non_stream
(
text
)
except
Exception
as
e
:
logger
.
error
(
f
"Exception:
{
e
}
"
)
return
create_error_response
(
HTTPStatus
.
BAD_REQUEST
,
"Failed to parse reasoning related info to json format!"
,
)
if
tool_choice
!=
"none"
and
tools
:
parser
=
FunctionCallParser
(
tools
,
tool_call_parser
)
if
parser
.
has_tool_call
(
text
):
if
finish_reason
[
"type"
]
==
"stop"
:
finish_reason
[
"type"
]
=
"tool_calls"
finish_reason
[
"matched"
]
=
None
try
:
text
,
call_info_list
=
parser
.
parse_non_stream
(
text
)
tool_calls
=
[
ToolCall
(
id
=
f
"call_
{
base64
.
urlsafe_b64encode
(
uuid
.
uuid4
().
bytes
).
rstrip
(
b
'='
).
decode
()
}
"
,
function
=
FunctionResponse
(
name
=
call_info
.
name
,
arguments
=
call_info
.
parameters
),
)
for
call_info
in
call_info_list
]
except
Exception
as
e
:
logger
.
error
(
f
"Exception:
{
e
}
"
)
return
create_error_response
(
HTTPStatus
.
BAD_REQUEST
,
"Failed to parse fc related info to json format!"
,
)
if
to_file
:
# to make the choice data json serializable
choice_data
=
{
"index"
:
0
,
"message"
:
{
"role"
:
"assistant"
,
"content"
:
text
if
text
else
None
,
"tool_calls"
:
tool_calls
,
"reasoning_content"
:
reasoning_text
if
reasoning_text
else
None
,
},
"logprobs"
:
choice_logprobs
.
model_dump
()
if
choice_logprobs
else
None
,
"finish_reason"
:
finish_reason
[
"type"
]
if
finish_reason
else
None
,
"matched_stop"
:
(
finish_reason
[
"matched"
]
if
finish_reason
and
"matched"
in
finish_reason
else
None
),
}
if
hidden_states
is
not
None
:
choice_data
[
"hidden_states"
]
=
hidden_states
else
:
choice_data
=
ChatCompletionResponseChoice
(
index
=
idx
,
message
=
ChatMessage
(
role
=
"assistant"
,
content
=
text
if
text
else
None
,
tool_calls
=
tool_calls
,
reasoning_content
=
reasoning_text
if
reasoning_text
else
None
,
),
logprobs
=
choice_logprobs
,
finish_reason
=
finish_reason
[
"type"
]
if
finish_reason
else
None
,
matched_stop
=
(
finish_reason
[
"matched"
]
if
finish_reason
and
"matched"
in
finish_reason
else
None
),
hidden_states
=
hidden_states
,
)
choices
.
append
(
choice_data
)
if
to_file
:
responses
=
[]
for
i
,
choice
in
enumerate
(
choices
):
response
=
{
"status_code"
:
200
,
"request_id"
:
ret
[
i
][
"meta_info"
][
"id"
],
"body"
:
{
# remain the same but if needed we can change that
"id"
:
ret
[
i
][
"meta_info"
][
"id"
],
"object"
:
"chat.completion"
,
"created"
:
created
,
"model"
:
(
request
[
i
].
model
if
isinstance
(
request
,
list
)
else
request
.
model
),
"choices"
:
choice
,
"usage"
:
{
"prompt_tokens"
:
ret
[
i
][
"meta_info"
][
"prompt_tokens"
],
"completion_tokens"
:
ret
[
i
][
"meta_info"
][
"completion_tokens"
],
"total_tokens"
:
ret
[
i
][
"meta_info"
][
"prompt_tokens"
]
+
ret
[
i
][
"meta_info"
][
"completion_tokens"
],
},
"system_fingerprint"
:
None
,
},
}
responses
.
append
(
response
)
return
responses
else
:
prompt_tokens
=
sum
(
ret
[
i
][
"meta_info"
][
"prompt_tokens"
]
for
i
in
range
(
0
,
len
(
ret
),
request
.
n
)
)
completion_tokens
=
sum
(
item
[
"meta_info"
][
"completion_tokens"
]
for
item
in
ret
)
cached_tokens
=
sum
(
item
[
"meta_info"
].
get
(
"cached_tokens"
,
0
)
for
item
in
ret
)
response
=
ChatCompletionResponse
(
id
=
ret
[
0
][
"meta_info"
][
"id"
],
created
=
created
,
model
=
request
.
model
,
choices
=
choices
,
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
completion_tokens
,
prompt_tokens_details
=
(
{
"cached_tokens"
:
cached_tokens
}
if
cache_report
else
None
),
),
)
return
response
async
def
v1_chat_completions
(
tokenizer_manager
,
raw_request
:
Request
,
cache_report
=
False
):
try
:
request_json
=
await
raw_request
.
json
()
except
Exception
as
e
:
return
create_error_response
(
"Invalid request body, error: "
,
str
(
e
))
all_requests
=
[
ChatCompletionRequest
(
**
request_json
)]
created
=
int
(
time
.
time
())
adapted_request
,
request
=
v1_chat_generate_request
(
all_requests
,
tokenizer_manager
,
request_ids
=
[
all_requests
[
0
].
rid
]
)
if
adapted_request
.
stream
:
parser_dict
=
{}
reasoning_parser_dict
=
{}
async
def
generate_stream_resp
():
tool_index_previous
=
-
1
is_firsts
=
{}
stream_buffers
=
{}
n_prev_tokens
=
{}
prompt_tokens
=
{}
completion_tokens
=
{}
cached_tokens
=
{}
hidden_states
=
{}
try
:
async
for
content
in
tokenizer_manager
.
generate_request
(
adapted_request
,
raw_request
):
index
=
content
.
get
(
"index"
,
0
)
text
=
content
[
"text"
]
hidden_states
[
index
]
=
content
[
"meta_info"
].
get
(
"hidden_states"
,
None
)
or
hidden_states
.
get
(
index
)
is_first
=
is_firsts
.
get
(
index
,
True
)
stream_buffer
=
stream_buffers
.
get
(
index
,
""
)
n_prev_token
=
n_prev_tokens
.
get
(
index
,
0
)
prompt_tokens
[
index
]
=
content
[
"meta_info"
][
"prompt_tokens"
]
completion_tokens
[
index
]
=
content
[
"meta_info"
][
"completion_tokens"
]
cached_tokens
[
index
]
=
content
[
"meta_info"
].
get
(
"cached_tokens"
,
0
)
if
request
.
logprobs
:
logprobs
=
to_openai_style_logprobs
(
output_token_logprobs
=
content
[
"meta_info"
][
"output_token_logprobs"
][
n_prev_token
:],
output_top_logprobs
=
content
[
"meta_info"
].
get
(
"output_top_logprobs"
,
[]
)[
n_prev_token
:],
)
n_prev_token
=
len
(
content
[
"meta_info"
][
"output_token_logprobs"
]
)
token_logprobs
=
[]
for
token
,
logprob
in
zip
(
logprobs
.
tokens
,
logprobs
.
token_logprobs
):
token_bytes
=
list
(
token
.
encode
(
"utf-8"
))
top_logprobs
=
[]
if
logprobs
.
top_logprobs
:
for
top_token
,
top_logprob
in
logprobs
.
top_logprobs
[
0
].
items
():
top_token_bytes
=
list
(
top_token
.
encode
(
"utf-8"
))
top_logprobs
.
append
(
TopLogprob
(
token
=
top_token
,
bytes
=
top_token_bytes
,
logprob
=
top_logprob
,
)
)
token_logprobs
.
append
(
ChatCompletionTokenLogprob
(
token
=
token
,
bytes
=
token_bytes
,
logprob
=
logprob
,
top_logprobs
=
top_logprobs
,
)
)
choice_logprobs
=
ChoiceLogprobs
(
content
=
token_logprobs
)
else
:
choice_logprobs
=
None
finish_reason
=
content
[
"meta_info"
][
"finish_reason"
]
finish_reason_type
=
(
finish_reason
[
"type"
]
if
finish_reason
else
None
)
if
is_first
:
# First chunk with role
is_first
=
False
delta
=
DeltaMessage
(
role
=
"assistant"
)
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
index
,
delta
=
delta
,
finish_reason
=
finish_reason_type
,
matched_stop
=
(
finish_reason
[
"matched"
]
if
finish_reason
and
"matched"
in
finish_reason
else
None
),
logprobs
=
choice_logprobs
,
)
chunk
=
ChatCompletionStreamResponse
(
id
=
content
[
"meta_info"
][
"id"
],
created
=
created
,
choices
=
[
choice_data
],
model
=
request
.
model
,
)
yield
f
"data:
{
chunk
.
model_dump_json
()
}
\n\n
"
text
=
content
[
"text"
]
delta
=
text
[
len
(
stream_buffer
)
:]
new_stream_buffer
=
stream_buffer
+
delta
enable_thinking
=
_get_enable_thinking_from_request
(
request
)
if
(
tokenizer_manager
.
server_args
.
reasoning_parser
and
request
.
separate_reasoning
and
enable_thinking
):
if
index
not
in
reasoning_parser_dict
:
reasoning_parser_dict
[
index
]
=
ReasoningParser
(
tokenizer_manager
.
server_args
.
reasoning_parser
,
request
.
stream_reasoning
,
)
reasoning_parser
=
reasoning_parser_dict
[
index
]
reasoning_text
,
delta
=
reasoning_parser
.
parse_stream_chunk
(
delta
)
if
reasoning_text
:
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
index
,
delta
=
DeltaMessage
(
reasoning_content
=
(
reasoning_text
if
reasoning_text
else
None
)
),
finish_reason
=
finish_reason_type
,
)
chunk
=
ChatCompletionStreamResponse
(
id
=
content
[
"meta_info"
][
"id"
],
created
=
created
,
choices
=
[
choice_data
],
model
=
request
.
model
,
)
yield
f
"data:
{
chunk
.
model_dump_json
()
}
\n\n
"
if
(
delta
and
len
(
delta
)
==
0
)
or
not
delta
:
stream_buffers
[
index
]
=
new_stream_buffer
is_firsts
[
index
]
=
is_first
n_prev_tokens
[
index
]
=
n_prev_token
continue
if
request
.
tool_choice
!=
"none"
and
request
.
tools
:
if
index
not
in
parser_dict
:
parser_dict
[
index
]
=
FunctionCallParser
(
tools
=
request
.
tools
,
tool_call_parser
=
tokenizer_manager
.
server_args
.
tool_call_parser
,
)
parser
=
parser_dict
[
index
]
# parse_increment => returns (normal_text, calls)
normal_text
,
calls
=
parser
.
parse_stream_chunk
(
delta
)
# 1) if there's normal_text, output it as normal content
if
normal_text
:
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
index
,
delta
=
DeltaMessage
(
content
=
normal_text
if
normal_text
else
None
),
finish_reason
=
finish_reason_type
,
)
chunk
=
ChatCompletionStreamResponse
(
id
=
content
[
"meta_info"
][
"id"
],
created
=
created
,
choices
=
[
choice_data
],
model
=
request
.
model
,
)
yield
f
"data:
{
chunk
.
model_dump_json
()
}
\n\n
"
# 2) if we found calls, we output them as separate chunk(s)
for
call_item
in
calls
:
tool_index_current
=
call_item
.
tool_index
# transform call_item -> FunctionResponse + ToolCall
if
finish_reason_type
==
"stop"
:
latest_delta_len
=
0
if
isinstance
(
call_item
.
parameters
,
str
):
latest_delta_len
=
len
(
call_item
.
parameters
)
expected_call
=
json
.
dumps
(
parser
.
detector
.
prev_tool_call_arr
[
index
].
get
(
"arguments"
,
{}
),
ensure_ascii
=
False
,
)
actual_call
=
parser
.
detector
.
streamed_args_for_tool
[
index
]
if
latest_delta_len
>
0
:
actual_call
=
actual_call
[:
-
latest_delta_len
]
remaining_call
=
expected_call
.
replace
(
actual_call
,
""
,
1
)
call_item
.
parameters
=
remaining_call
finish_reason_type
=
"tool_calls"
tool_call
=
ToolCall
(
id
=
(
f
"call_
{
base64
.
urlsafe_b64encode
(
uuid
.
uuid4
().
bytes
).
rstrip
(
b
'='
).
decode
()
}
"
if
tool_index_previous
!=
tool_index_current
else
None
),
index
=
call_item
.
tool_index
,
function
=
FunctionResponse
(
name
=
call_item
.
name
,
arguments
=
call_item
.
parameters
,
),
)
tool_index_previous
=
tool_index_current
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
index
,
delta
=
DeltaMessage
(
tool_calls
=
[
tool_call
]),
finish_reason
=
(
None
if
request
.
stream_options
and
request
.
stream_options
.
include_usage
else
finish_reason_type
),
# additional chunk will be return
)
chunk
=
ChatCompletionStreamResponse
(
id
=
content
[
"meta_info"
][
"id"
],
created
=
created
,
choices
=
[
choice_data
],
model
=
request
.
model
,
)
yield
f
"data:
{
chunk
.
model_dump_json
()
}
\n\n
"
stream_buffers
[
index
]
=
new_stream_buffer
is_firsts
[
index
]
=
is_first
n_prev_tokens
[
index
]
=
n_prev_token
else
:
# No tool calls => just treat this as normal text
if
delta
or
not
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
index
,
delta
=
DeltaMessage
(
content
=
delta
if
delta
else
None
),
finish_reason
=
(
None
if
request
.
stream_options
and
request
.
stream_options
.
include_usage
else
finish_reason_type
),
matched_stop
=
(
finish_reason
[
"matched"
]
if
finish_reason
and
"matched"
in
finish_reason
else
None
),
logprobs
=
choice_logprobs
,
)
chunk
=
ChatCompletionStreamResponse
(
id
=
content
[
"meta_info"
][
"id"
],
created
=
created
,
choices
=
[
choice_data
],
model
=
request
.
model
,
)
yield
f
"data:
{
chunk
.
model_dump_json
()
}
\n\n
"
stream_buffers
[
index
]
=
new_stream_buffer
is_firsts
[
index
]
=
is_first
n_prev_tokens
[
index
]
=
n_prev_token
if
finish_reason_type
==
"stop"
and
request
.
tool_choice
!=
"none"
:
parser
=
FunctionCallParser
(
tools
=
request
.
tools
,
tool_call_parser
=
tokenizer_manager
.
server_args
.
tool_call_parser
,
)
if
parser
.
has_tool_call
(
new_stream_buffer
):
# if the stream ends with empty string after tool calls
finish_reason_type
=
"tool_calls"
if
request
.
stream_options
and
request
.
stream_options
.
include_usage
:
total_prompt_tokens
=
sum
(
tokens
for
i
,
tokens
in
prompt_tokens
.
items
()
if
i
%
request
.
n
==
0
)
total_completion_tokens
=
sum
(
tokens
for
tokens
in
completion_tokens
.
values
()
)
cache_report
=
tokenizer_manager
.
server_args
.
enable_cache_report
if
cache_report
:
cached_tokens_sum
=
sum
(
tokens
for
tokens
in
cached_tokens
.
values
()
)
prompt_tokens_details
=
{
"cached_tokens"
:
cached_tokens_sum
}
else
:
prompt_tokens_details
=
None
usage
=
UsageInfo
(
prompt_tokens
=
total_prompt_tokens
,
completion_tokens
=
total_completion_tokens
,
total_tokens
=
total_prompt_tokens
+
total_completion_tokens
,
prompt_tokens_details
=
prompt_tokens_details
,
)
else
:
usage
=
None
if
request
.
return_hidden_states
and
hidden_states
:
for
index
,
choice_hidden_states
in
hidden_states
.
items
():
last_token_hidden_states
=
(
choice_hidden_states
[
-
1
]
if
choice_hidden_states
and
len
(
choice_hidden_states
)
>
1
else
[]
)
hidden_states_chunk
=
ChatCompletionStreamResponse
(
id
=
content
[
"meta_info"
][
"id"
],
created
=
created
,
choices
=
[
ChatCompletionResponseStreamChoice
(
index
=
index
,
delta
=
DeltaMessage
(
hidden_states
=
last_token_hidden_states
),
finish_reason
=
finish_reason_type
,
)
],
model
=
request
.
model
,
)
yield
f
"data:
{
hidden_states_chunk
.
model_dump_json
()
}
\n\n
"
final_usage_chunk
=
ChatCompletionStreamResponse
(
id
=
content
[
"meta_info"
][
"id"
],
created
=
created
,
choices
=
[
ChatCompletionResponseStreamChoice
(
index
=
index
,
delta
=
DeltaMessage
(),
finish_reason
=
finish_reason_type
,
)
],
model
=
request
.
model
,
usage
=
usage
,
)
yield
f
"data:
{
final_usage_chunk
.
model_dump_json
()
}
\n\n
"
except
ValueError
as
e
:
error
=
create_streaming_error_response
(
str
(
e
))
yield
f
"data:
{
error
}
\n\n
"
yield
"data: [DONE]
\n\n
"
return
StreamingResponse
(
generate_stream_resp
(),
media_type
=
"text/event-stream"
,
background
=
tokenizer_manager
.
create_abort_task
(
adapted_request
),
)
# Non-streaming response.
try
:
ret
=
await
tokenizer_manager
.
generate_request
(
adapted_request
,
raw_request
).
__anext__
()
except
ValueError
as
e
:
return
create_error_response
(
str
(
e
))
if
not
isinstance
(
ret
,
list
):
ret
=
[
ret
]
response
=
v1_chat_generate_response
(
request
,
ret
,
created
,
cache_report
=
tokenizer_manager
.
server_args
.
enable_cache_report
,
tool_call_parser
=
tokenizer_manager
.
server_args
.
tool_call_parser
,
reasoning_parser
=
tokenizer_manager
.
server_args
.
reasoning_parser
,
)
return
response
def
v1_embedding_request
(
all_requests
,
tokenizer_manager
):
prompts
=
[]
sampling_params_list
=
[]
first_prompt_type
=
type
(
all_requests
[
0
].
input
)
for
request
in
all_requests
:
prompt
=
request
.
input
# Check for empty/whitespace string
prompt
=
_validate_prompt
(
request
.
input
)
assert
(
type
(
prompt
)
is
first_prompt_type
),
"All prompts must be of the same type in file input settings"
prompts
.
append
(
prompt
)
if
len
(
all_requests
)
==
1
:
prompt
=
prompts
[
0
]
if
isinstance
(
prompt
,
str
)
or
isinstance
(
prompt
[
0
],
str
):
prompt_kwargs
=
{
"text"
:
prompt
}
elif
isinstance
(
prompt
,
list
)
and
isinstance
(
prompt
[
0
],
MultimodalEmbeddingInput
):
texts
=
[]
images
=
[]
for
item
in
prompt
:
# TODO simply use padding for text, we should use a better way to handle this
texts
.
append
(
item
.
text
if
item
.
text
is
not
None
else
"padding"
)
images
.
append
(
item
.
image
if
item
.
image
is
not
None
else
None
)
generate_prompts
=
[]
if
chat_template_name
is
not
None
:
convs
=
generate_embedding_convs
(
texts
,
images
,
chat_template_name
)
for
conv
in
convs
:
generate_prompts
.
append
(
conv
.
get_prompt
())
else
:
generate_prompts
=
texts
if
len
(
generate_prompts
)
==
1
:
prompt_kwargs
=
{
"text"
:
generate_prompts
[
0
],
"image_data"
:
images
[
0
]}
else
:
prompt_kwargs
=
{
"text"
:
generate_prompts
,
"image_data"
:
images
}
else
:
prompt_kwargs
=
{
"input_ids"
:
prompt
}
request_ids
=
all_requests
[
0
].
rid
else
:
if
isinstance
(
prompts
[
0
],
str
)
or
isinstance
(
prompts
[
0
][
0
],
str
):
prompt_kwargs
=
{
"text"
:
prompts
}
elif
isinstance
(
prompts
[
0
],
list
)
and
isinstance
(
prompts
[
0
][
0
],
MultimodalEmbeddingInput
):
# TODO: multiple requests
raise
NotImplementedError
(
"Multiple requests with multimodal inputs are not supported yet"
)
else
:
prompt_kwargs
=
{
"input_ids"
:
prompts
}
request_ids
=
[
req
.
rid
for
req
in
all_requests
]
adapted_request
=
EmbeddingReqInput
(
rid
=
request_ids
,
**
prompt_kwargs
,
)
if
len
(
all_requests
)
==
1
:
return
adapted_request
,
all_requests
[
0
]
return
adapted_request
,
all_requests
def
v1_embedding_response
(
ret
,
model_path
,
to_file
=
False
):
embedding_objects
=
[]
prompt_tokens
=
0
for
idx
,
ret_item
in
enumerate
(
ret
):
embedding_objects
.
append
(
EmbeddingObject
(
embedding
=
ret
[
idx
][
"embedding"
],
index
=
idx
,
)
)
prompt_tokens
+=
ret
[
idx
][
"meta_info"
][
"prompt_tokens"
]
return
EmbeddingResponse
(
data
=
embedding_objects
,
model
=
model_path
,
usage
=
UsageInfo
(
prompt_tokens
=
prompt_tokens
,
total_tokens
=
prompt_tokens
,
),
)
async
def
v1_embeddings
(
tokenizer_manager
,
raw_request
:
Request
):
try
:
request_json
=
await
raw_request
.
json
()
except
Exception
as
e
:
return
create_error_response
(
"Invalid request body, error: "
,
str
(
e
))
all_requests
=
[
EmbeddingRequest
(
**
request_json
)]
adapted_request
,
request
=
v1_embedding_request
(
all_requests
,
tokenizer_manager
)
try
:
ret
=
await
tokenizer_manager
.
generate_request
(
adapted_request
,
raw_request
).
__anext__
()
except
ValueError
as
e
:
return
create_error_response
(
str
(
e
))
if
not
isinstance
(
ret
,
list
):
ret
=
[
ret
]
response
=
v1_embedding_response
(
ret
,
tokenizer_manager
.
model_path
)
return
response
def
v1_rerank_request
(
obj
:
V1RerankReqInput
):
if
obj
.
query
is
None
:
raise
ValueError
(
"query is required"
)
if
obj
.
documents
is
None
or
len
(
obj
.
documents
)
==
0
:
raise
ValueError
(
"documents is required"
)
pairs
=
[]
for
doc
in
obj
.
documents
:
pairs
.
append
([
obj
.
query
,
doc
])
adapted_request
=
EmbeddingReqInput
(
text
=
pairs
,
is_cross_encoder_request
=
True
,
)
return
adapted_request
def
v1_rerank_response
(
ret
,
obj
:
V1RerankReqInput
):
response
=
[]
for
idx
,
ret_item
in
enumerate
(
ret
):
response
.
append
(
RerankResponse
(
score
=
ret
[
idx
][
"embedding"
],
document
=
obj
.
documents
[
idx
],
index
=
idx
,
meta_info
=
ret
[
idx
][
"meta_info"
],
)
)
response
.
sort
(
key
=
lambda
x
:
x
.
score
,
reverse
=
True
)
return
response
async
def
v1_rerank
(
tokenizer_manager
,
obj
:
V1RerankReqInput
,
raw_request
:
Request
):
adapted_request
=
v1_rerank_request
(
obj
)
try
:
ret
=
await
tokenizer_manager
.
generate_request
(
adapted_request
,
raw_request
).
__anext__
()
except
ValueError
as
e
:
return
create_error_response
(
str
(
e
))
if
not
isinstance
(
ret
,
list
):
ret
=
[
ret
]
response
=
v1_rerank_response
(
ret
,
obj
,
)
return
response
def
to_openai_style_logprobs
(
input_token_logprobs
=
None
,
output_token_logprobs
=
None
,
input_top_logprobs
=
None
,
output_top_logprobs
=
None
,
):
ret_logprobs
=
LogProbs
()
def
append_token_logprobs
(
token_logprobs
):
for
logprob
,
_
,
token_text
in
token_logprobs
:
ret_logprobs
.
tokens
.
append
(
token_text
)
ret_logprobs
.
token_logprobs
.
append
(
logprob
)
# Not supported yet
ret_logprobs
.
text_offset
.
append
(
-
1
)
def
append_top_logprobs
(
top_logprobs
):
for
tokens
in
top_logprobs
:
if
tokens
is
not
None
:
ret_logprobs
.
top_logprobs
.
append
(
{
token
[
2
]:
token
[
0
]
for
token
in
tokens
}
)
else
:
ret_logprobs
.
top_logprobs
.
append
(
None
)
if
input_token_logprobs
is
not
None
:
append_token_logprobs
(
input_token_logprobs
)
if
output_token_logprobs
is
not
None
:
append_token_logprobs
(
output_token_logprobs
)
if
input_top_logprobs
is
not
None
:
append_top_logprobs
(
input_top_logprobs
)
if
output_top_logprobs
is
not
None
:
append_top_logprobs
(
output_top_logprobs
)
return
ret_logprobs
async
def
v1_score
(
tokenizer_manager
,
raw_request
):
try
:
# Parse request
request_data
=
await
raw_request
.
json
()
request
=
ScoringRequest
(
**
request_data
)
# Use tokenizer_manager's score_request method directly
scores
=
await
tokenizer_manager
.
score_request
(
query
=
request
.
query
,
items
=
request
.
items
,
label_token_ids
=
request
.
label_token_ids
,
apply_softmax
=
request
.
apply_softmax
,
item_first
=
request
.
item_first
,
request
=
request
,
)
# Create response with just the scores, without usage info
response
=
ScoringResponse
(
scores
=
scores
,
model
=
request
.
model
,
)
return
response
except
Exception
as
e
:
logger
.
error
(
f
"Error in v1_score:
{
str
(
e
)
}
"
)
return
create_error_response
(
str
(
e
))
python/sglang/srt/openai_api/protocol.py
deleted
100644 → 0
View file @
02bf31ef
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pydantic models for OpenAI API protocol"""
import
time
from
typing
import
Dict
,
List
,
Optional
,
Union
from
pydantic
import
BaseModel
,
Field
,
model_serializer
,
root_validator
from
typing_extensions
import
Literal
class
ModelCard
(
BaseModel
):
"""Model cards."""
id
:
str
object
:
str
=
"model"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
owned_by
:
str
=
"sglang"
root
:
Optional
[
str
]
=
None
max_model_len
:
Optional
[
int
]
=
None
class
ModelList
(
BaseModel
):
"""Model list consists of model cards."""
object
:
str
=
"list"
data
:
List
[
ModelCard
]
=
Field
(
default_factory
=
list
)
class
ErrorResponse
(
BaseModel
):
object
:
str
=
"error"
message
:
str
type
:
str
param
:
Optional
[
str
]
=
None
code
:
int
class
LogProbs
(
BaseModel
):
text_offset
:
List
[
int
]
=
Field
(
default_factory
=
list
)
token_logprobs
:
List
[
Optional
[
float
]]
=
Field
(
default_factory
=
list
)
tokens
:
List
[
str
]
=
Field
(
default_factory
=
list
)
top_logprobs
:
List
[
Optional
[
Dict
[
str
,
float
]]]
=
Field
(
default_factory
=
list
)
class
TopLogprob
(
BaseModel
):
token
:
str
bytes
:
List
[
int
]
logprob
:
float
class
ChatCompletionTokenLogprob
(
BaseModel
):
token
:
str
bytes
:
List
[
int
]
logprob
:
float
top_logprobs
:
List
[
TopLogprob
]
class
ChoiceLogprobs
(
BaseModel
):
# build for v1/chat/completions response
content
:
List
[
ChatCompletionTokenLogprob
]
class
UsageInfo
(
BaseModel
):
prompt_tokens
:
int
=
0
total_tokens
:
int
=
0
completion_tokens
:
Optional
[
int
]
=
0
# only used to return cached tokens when --enable-cache-report is set
prompt_tokens_details
:
Optional
[
Dict
[
str
,
int
]]
=
None
class
StreamOptions
(
BaseModel
):
include_usage
:
Optional
[
bool
]
=
False
class
JsonSchemaResponseFormat
(
BaseModel
):
name
:
str
description
:
Optional
[
str
]
=
None
# use alias to workaround pydantic conflict
schema_
:
Optional
[
Dict
[
str
,
object
]]
=
Field
(
alias
=
"schema"
,
default
=
None
)
strict
:
Optional
[
bool
]
=
False
class
FileRequest
(
BaseModel
):
# https://platform.openai.com/docs/api-reference/files/create
file
:
bytes
# The File object (not file name) to be uploaded
purpose
:
str
=
(
"batch"
# The intended purpose of the uploaded file, default is "batch"
)
class
FileResponse
(
BaseModel
):
id
:
str
object
:
str
=
"file"
bytes
:
int
created_at
:
int
filename
:
str
purpose
:
str
class
FileDeleteResponse
(
BaseModel
):
id
:
str
object
:
str
=
"file"
deleted
:
bool
class
BatchRequest
(
BaseModel
):
input_file_id
:
(
str
# The ID of an uploaded file that contains requests for the new batch
)
endpoint
:
str
# The endpoint to be used for all requests in the batch
completion_window
:
str
# The time frame within which the batch should be processed
metadata
:
Optional
[
dict
]
=
None
# Optional custom metadata for the batch
class
BatchResponse
(
BaseModel
):
id
:
str
object
:
str
=
"batch"
endpoint
:
str
errors
:
Optional
[
dict
]
=
None
input_file_id
:
str
completion_window
:
str
status
:
str
=
"validating"
output_file_id
:
Optional
[
str
]
=
None
error_file_id
:
Optional
[
str
]
=
None
created_at
:
int
in_progress_at
:
Optional
[
int
]
=
None
expires_at
:
Optional
[
int
]
=
None
finalizing_at
:
Optional
[
int
]
=
None
completed_at
:
Optional
[
int
]
=
None
failed_at
:
Optional
[
int
]
=
None
expired_at
:
Optional
[
int
]
=
None
cancelling_at
:
Optional
[
int
]
=
None
cancelled_at
:
Optional
[
int
]
=
None
request_counts
:
Optional
[
dict
]
=
None
metadata
:
Optional
[
dict
]
=
None
class
CompletionRequest
(
BaseModel
):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model
:
str
prompt
:
Union
[
List
[
int
],
List
[
List
[
int
]],
str
,
List
[
str
]]
best_of
:
Optional
[
int
]
=
None
echo
:
bool
=
False
frequency_penalty
:
float
=
0.0
logit_bias
:
Optional
[
Dict
[
str
,
float
]]
=
None
logprobs
:
Optional
[
int
]
=
None
max_tokens
:
int
=
16
n
:
int
=
1
presence_penalty
:
float
=
0.0
seed
:
Optional
[
int
]
=
None
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
stream
:
bool
=
False
stream_options
:
Optional
[
StreamOptions
]
=
None
suffix
:
Optional
[
str
]
=
None
temperature
:
float
=
1.0
top_p
:
float
=
1.0
user
:
Optional
[
str
]
=
None
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k
:
int
=
-
1
min_p
:
float
=
0.0
min_tokens
:
int
=
0
json_schema
:
Optional
[
str
]
=
None
regex
:
Optional
[
str
]
=
None
ebnf
:
Optional
[
str
]
=
None
repetition_penalty
:
float
=
1.0
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
no_stop_trim
:
bool
=
False
ignore_eos
:
bool
=
False
skip_special_tokens
:
bool
=
True
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
session_params
:
Optional
[
Dict
]
=
None
return_hidden_states
:
Optional
[
bool
]
=
False
# For PD disaggregation
bootstrap_host
:
Optional
[
str
]
=
None
bootstrap_port
:
Optional
[
int
]
=
None
bootstrap_room
:
Optional
[
int
]
=
None
class
CompletionResponseChoice
(
BaseModel
):
index
:
int
text
:
str
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Literal
[
"stop"
,
"length"
,
"content_filter"
,
"abort"
]
matched_stop
:
Union
[
None
,
int
,
str
]
=
None
hidden_states
:
Optional
[
object
]
=
None
@
model_serializer
def
_serialize
(
self
):
return
exclude_if_none
(
self
,
[
"hidden_states"
])
class
CompletionResponse
(
BaseModel
):
id
:
str
object
:
str
=
"text_completion"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
model
:
str
choices
:
List
[
CompletionResponseChoice
]
usage
:
UsageInfo
class
CompletionResponseStreamChoice
(
BaseModel
):
index
:
int
text
:
str
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
,
"content_filter"
]]
=
None
matched_stop
:
Union
[
None
,
int
,
str
]
=
None
hidden_states
:
Optional
[
object
]
=
None
@
model_serializer
def
_serialize
(
self
):
return
exclude_if_none
(
self
,
[
"hidden_states"
])
class
CompletionStreamResponse
(
BaseModel
):
id
:
str
object
:
str
=
"text_completion"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
model
:
str
choices
:
List
[
CompletionResponseStreamChoice
]
usage
:
Optional
[
UsageInfo
]
=
None
class
ChatCompletionMessageContentTextPart
(
BaseModel
):
type
:
Literal
[
"text"
]
text
:
str
class
ChatCompletionMessageContentImageURL
(
BaseModel
):
url
:
str
detail
:
Optional
[
Literal
[
"auto"
,
"low"
,
"high"
]]
=
"auto"
class
ChatCompletionMessageContentAudioURL
(
BaseModel
):
url
:
str
class
ChatCompletionMessageContentImagePart
(
BaseModel
):
type
:
Literal
[
"image_url"
]
image_url
:
ChatCompletionMessageContentImageURL
modalities
:
Optional
[
Literal
[
"image"
,
"multi-images"
,
"video"
]]
=
"image"
class
ChatCompletionMessageContentAudioPart
(
BaseModel
):
type
:
Literal
[
"audio_url"
]
audio_url
:
ChatCompletionMessageContentAudioURL
ChatCompletionMessageContentPart
=
Union
[
ChatCompletionMessageContentTextPart
,
ChatCompletionMessageContentImagePart
,
ChatCompletionMessageContentAudioPart
,
]
class
FunctionResponse
(
BaseModel
):
"""Function response."""
name
:
Optional
[
str
]
=
None
arguments
:
Optional
[
str
]
=
None
class
ToolCall
(
BaseModel
):
"""Tool call response."""
id
:
Optional
[
str
]
=
None
index
:
Optional
[
int
]
=
None
type
:
Literal
[
"function"
]
=
"function"
function
:
FunctionResponse
class
ChatCompletionMessageGenericParam
(
BaseModel
):
role
:
Literal
[
"system"
,
"assistant"
,
"tool"
]
content
:
Union
[
str
,
List
[
ChatCompletionMessageContentTextPart
],
None
]
tool_call_id
:
Optional
[
str
]
=
None
name
:
Optional
[
str
]
=
None
reasoning_content
:
Optional
[
str
]
=
None
tool_calls
:
Optional
[
List
[
ToolCall
]]
=
Field
(
default
=
None
,
examples
=
[
None
])
class
ChatCompletionMessageUserParam
(
BaseModel
):
role
:
Literal
[
"user"
]
content
:
Union
[
str
,
List
[
ChatCompletionMessageContentPart
]]
ChatCompletionMessageParam
=
Union
[
ChatCompletionMessageGenericParam
,
ChatCompletionMessageUserParam
]
class
ResponseFormat
(
BaseModel
):
type
:
Literal
[
"text"
,
"json_object"
,
"json_schema"
]
json_schema
:
Optional
[
JsonSchemaResponseFormat
]
=
None
class
StructuresResponseFormat
(
BaseModel
):
begin
:
str
schema_
:
Optional
[
Dict
[
str
,
object
]]
=
Field
(
alias
=
"schema"
,
default
=
None
)
end
:
str
class
StructuralTagResponseFormat
(
BaseModel
):
type
:
Literal
[
"structural_tag"
]
structures
:
List
[
StructuresResponseFormat
]
triggers
:
List
[
str
]
class
Function
(
BaseModel
):
"""Function descriptions."""
description
:
Optional
[
str
]
=
Field
(
default
=
None
,
examples
=
[
None
])
name
:
Optional
[
str
]
=
None
parameters
:
Optional
[
object
]
=
None
strict
:
bool
=
False
class
Tool
(
BaseModel
):
"""Function wrapper."""
type
:
str
=
Field
(
default
=
"function"
,
examples
=
[
"function"
])
function
:
Function
class
ToolChoiceFuncName
(
BaseModel
):
"""The name of tool choice function."""
name
:
Optional
[
str
]
=
None
class
ToolChoice
(
BaseModel
):
"""The tool choice definition."""
function
:
ToolChoiceFuncName
type
:
Literal
[
"function"
]
=
Field
(
default
=
"function"
,
examples
=
[
"function"
])
class
ChatCompletionRequest
(
BaseModel
):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages
:
List
[
ChatCompletionMessageParam
]
model
:
str
frequency_penalty
:
float
=
0.0
logit_bias
:
Optional
[
Dict
[
str
,
float
]]
=
None
logprobs
:
bool
=
False
top_logprobs
:
Optional
[
int
]
=
None
max_tokens
:
Optional
[
int
]
=
Field
(
default
=
None
,
deprecated
=
"max_tokens is deprecated in favor of the max_completion_tokens field"
,
description
=
"The maximum number of tokens that can be generated in the chat completion. "
,
)
max_completion_tokens
:
Optional
[
int
]
=
Field
(
default
=
None
,
description
=
"The maximum number of completion tokens for a chat completion request, "
"including visible output tokens and reasoning tokens. Input tokens are not included. "
,
)
n
:
int
=
1
presence_penalty
:
float
=
0.0
response_format
:
Optional
[
Union
[
ResponseFormat
,
StructuralTagResponseFormat
]]
=
None
seed
:
Optional
[
int
]
=
None
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
stream
:
bool
=
False
stream_options
:
Optional
[
StreamOptions
]
=
None
temperature
:
float
=
0.7
top_p
:
float
=
1.0
user
:
Optional
[
str
]
=
None
tools
:
Optional
[
List
[
Tool
]]
=
Field
(
default
=
None
,
examples
=
[
None
])
tool_choice
:
Union
[
ToolChoice
,
Literal
[
"auto"
,
"required"
,
"none"
]]
=
Field
(
default
=
"auto"
,
examples
=
[
"none"
]
)
# noqa
@
root_validator
(
pre
=
True
)
def
set_tool_choice_default
(
cls
,
values
):
if
values
.
get
(
"tool_choice"
)
is
None
:
if
values
.
get
(
"tools"
)
is
None
:
values
[
"tool_choice"
]
=
"none"
else
:
values
[
"tool_choice"
]
=
"auto"
return
values
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k
:
int
=
-
1
min_p
:
float
=
0.0
min_tokens
:
int
=
0
regex
:
Optional
[
str
]
=
None
ebnf
:
Optional
[
str
]
=
None
repetition_penalty
:
float
=
1.0
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
no_stop_trim
:
bool
=
False
ignore_eos
:
bool
=
False
continue_final_message
:
bool
=
False
skip_special_tokens
:
bool
=
True
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
session_params
:
Optional
[
Dict
]
=
None
separate_reasoning
:
bool
=
True
stream_reasoning
:
bool
=
True
chat_template_kwargs
:
Optional
[
Dict
]
=
None
# The request id.
rid
:
Optional
[
str
]
=
None
# For PD disaggregation
bootstrap_host
:
Optional
[
str
]
=
None
bootstrap_port
:
Optional
[
int
]
=
None
bootstrap_room
:
Optional
[
int
]
=
None
# Hidden States
return_hidden_states
:
Optional
[
bool
]
=
False
class
ChatMessage
(
BaseModel
):
role
:
Optional
[
str
]
=
None
content
:
Optional
[
str
]
=
None
reasoning_content
:
Optional
[
str
]
=
None
tool_calls
:
Optional
[
List
[
ToolCall
]]
=
Field
(
default
=
None
,
examples
=
[
None
])
class
ChatCompletionResponseChoice
(
BaseModel
):
index
:
int
message
:
ChatMessage
logprobs
:
Optional
[
Union
[
LogProbs
,
ChoiceLogprobs
]]
=
None
finish_reason
:
Literal
[
"stop"
,
"length"
,
"tool_calls"
,
"content_filter"
,
"function_call"
,
"abort"
]
matched_stop
:
Union
[
None
,
int
,
str
]
=
None
hidden_states
:
Optional
[
object
]
=
None
@
model_serializer
def
_serialize
(
self
):
return
exclude_if_none
(
self
,
[
"hidden_states"
])
class
ChatCompletionResponse
(
BaseModel
):
id
:
str
object
:
str
=
"chat.completion"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
model
:
str
choices
:
List
[
ChatCompletionResponseChoice
]
usage
:
UsageInfo
class
DeltaMessage
(
BaseModel
):
role
:
Optional
[
str
]
=
None
content
:
Optional
[
str
]
=
None
reasoning_content
:
Optional
[
str
]
=
None
tool_calls
:
Optional
[
List
[
ToolCall
]]
=
Field
(
default
=
None
,
examples
=
[
None
])
hidden_states
:
Optional
[
object
]
=
None
@
model_serializer
def
_serialize
(
self
):
return
exclude_if_none
(
self
,
[
"hidden_states"
])
class
ChatCompletionResponseStreamChoice
(
BaseModel
):
index
:
int
delta
:
DeltaMessage
logprobs
:
Optional
[
Union
[
LogProbs
,
ChoiceLogprobs
]]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
,
"tool_calls"
,
"content_filter"
,
"function_call"
]
]
=
None
matched_stop
:
Union
[
None
,
int
,
str
]
=
None
class
ChatCompletionStreamResponse
(
BaseModel
):
id
:
str
object
:
str
=
"chat.completion.chunk"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
model
:
str
choices
:
List
[
ChatCompletionResponseStreamChoice
]
usage
:
Optional
[
UsageInfo
]
=
None
class
MultimodalEmbeddingInput
(
BaseModel
):
text
:
Optional
[
str
]
=
None
image
:
Optional
[
str
]
=
None
class
EmbeddingRequest
(
BaseModel
):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings/create
input
:
Union
[
List
[
int
],
List
[
List
[
int
]],
str
,
List
[
str
],
List
[
MultimodalEmbeddingInput
]
]
model
:
str
encoding_format
:
str
=
"float"
dimensions
:
int
=
None
user
:
Optional
[
str
]
=
None
# The request id.
rid
:
Optional
[
str
]
=
None
class
EmbeddingObject
(
BaseModel
):
embedding
:
List
[
float
]
index
:
int
object
:
str
=
"embedding"
class
EmbeddingResponse
(
BaseModel
):
data
:
List
[
EmbeddingObject
]
model
:
str
object
:
str
=
"list"
usage
:
Optional
[
UsageInfo
]
=
None
class
ScoringRequest
(
BaseModel
):
query
:
Optional
[
Union
[
str
,
List
[
int
]]]
=
(
None
# Query text or pre-tokenized token IDs
)
items
:
Optional
[
Union
[
str
,
List
[
str
],
List
[
List
[
int
]]]]
=
(
None
# Item text(s) or pre-tokenized token IDs
)
label_token_ids
:
Optional
[
List
[
int
]]
=
(
None
# Token IDs to compute probabilities for
)
apply_softmax
:
bool
=
False
item_first
:
bool
=
False
model
:
str
class
ScoringResponse
(
BaseModel
):
scores
:
List
[
List
[
float
]
]
# List of lists of probabilities, each in the order of label_token_ids
model
:
str
usage
:
Optional
[
UsageInfo
]
=
None
object
:
str
=
"scoring"
class
RerankResponse
(
BaseModel
):
score
:
float
document
:
str
index
:
int
meta_info
:
Optional
[
dict
]
=
None
def
exclude_if_none
(
obj
,
field_names
:
List
[
str
]):
omit_if_none_fields
=
{
k
for
k
,
v
in
obj
.
model_fields
.
items
()
if
k
in
field_names
}
return
{
k
:
v
for
k
,
v
in
obj
if
k
not
in
omit_if_none_fields
or
v
is
not
None
}
python/sglang/srt/reasoning_parser.py
View file @
72676cd6
from
typing
import
Dict
,
Tupl
e
from
typing
import
Dict
,
Optional
,
Tuple
,
Typ
e
class
StreamingParseResult
:
...
...
@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
One-time parsing: Detects and parses reasoning sections in the provided text.
Returns both reasoning content and normal text separately.
"""
text
=
text
.
replace
(
self
.
think_start_token
,
""
).
strip
()
if
self
.
think_end_token
not
in
text
:
in_reasoning
=
self
.
_in_reasoning
or
text
.
startswith
(
self
.
think_start_token
)
if
not
in_reasoning
:
return
StreamingParseResult
(
normal_text
=
text
)
# The text is considered to be in a reasoning block.
processed_text
=
text
.
replace
(
self
.
think_start_token
,
""
).
strip
()
if
self
.
think_end_token
not
in
processed_text
:
# Assume reasoning was truncated before `</think>` token
return
StreamingParseResult
(
reasoning_text
=
text
)
return
StreamingParseResult
(
reasoning_text
=
processed_
text
)
# Extract reasoning content
splits
=
text
.
split
(
self
.
think_end_token
,
maxsplit
=
1
)
splits
=
processed_
text
.
split
(
self
.
think_end_token
,
maxsplit
=
1
)
reasoning_text
=
splits
[
0
]
text
=
splits
[
1
].
strip
()
normal_
text
=
splits
[
1
].
strip
()
return
StreamingParseResult
(
normal_text
=
text
,
reasoning_text
=
reasoning_text
)
return
StreamingParseResult
(
normal_text
=
normal_text
,
reasoning_text
=
reasoning_text
)
def
parse_streaming_increment
(
self
,
new_text
:
str
)
->
StreamingParseResult
:
"""
...
...
@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
if
not
self
.
stripped_think_start
and
self
.
think_start_token
in
current_text
:
current_text
=
current_text
.
replace
(
self
.
think_start_token
,
""
)
self
.
stripped_think_start
=
True
self
.
_in_reasoning
=
True
# Handle end of reasoning block
if
self
.
_in_reasoning
and
self
.
think_end_token
in
current_text
:
...
...
@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
"""
def
__init__
(
self
,
stream_reasoning
:
bool
=
True
):
# Qwen3
is assumed to
be reasoning
until `</think>` token
# Qwen3
won't
be
in
reasoning
mode when user passes `enable_thinking=False`
super
().
__init__
(
"<think>"
,
"</think>"
,
force_reasoning
=
Tru
e
,
force_reasoning
=
Fals
e
,
stream_reasoning
=
stream_reasoning
,
)
...
...
@@ -151,12 +161,12 @@ class ReasoningParser:
If True, streams reasoning content as it arrives.
"""
DetectorMap
:
Dict
[
str
,
BaseReasoningFormatDetector
]
=
{
DetectorMap
:
Dict
[
str
,
Type
[
BaseReasoningFormatDetector
]
]
=
{
"deepseek-r1"
:
DeepSeekR1Detector
,
"qwen3"
:
Qwen3Detector
,
}
def
__init__
(
self
,
model_type
:
str
=
None
,
stream_reasoning
:
bool
=
True
):
def
__init__
(
self
,
model_type
:
Optional
[
str
]
=
None
,
stream_reasoning
:
bool
=
True
):
if
not
model_type
:
raise
ValueError
(
"Model type must be specified"
)
...
...
test/srt/openai/conftest.py
deleted
100644 → 0
View file @
02bf31ef
# sglang/test/srt/openai/conftest.py
import
os
import
socket
import
subprocess
import
sys
import
tempfile
import
time
from
contextlib
import
closing
from
typing
import
Generator
import
pytest
import
requests
from
sglang.srt.utils
import
kill_process_tree
# reuse SGLang helper
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
SERVER_MODULE
=
"sglang.srt.entrypoints.openai.api_server"
DEFAULT_MODEL
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
STARTUP_TIMEOUT
=
float
(
os
.
getenv
(
"SGLANG_OPENAI_STARTUP_TIMEOUT"
,
120
))
def
_pick_free_port
()
->
int
:
with
closing
(
socket
.
socket
())
as
s
:
s
.
bind
((
"127.0.0.1"
,
0
))
return
s
.
getsockname
()[
1
]
def
_wait_until_healthy
(
proc
:
subprocess
.
Popen
,
base
:
str
,
timeout
:
float
)
->
None
:
start
=
time
.
perf_counter
()
while
time
.
perf_counter
()
-
start
<
timeout
:
if
proc
.
poll
()
is
not
None
:
# crashed
raise
RuntimeError
(
"api_server terminated prematurely"
)
try
:
if
requests
.
get
(
f
"
{
base
}
/health"
,
timeout
=
1
).
status_code
==
200
:
return
except
requests
.
RequestException
:
pass
time
.
sleep
(
0.4
)
raise
RuntimeError
(
"api_server readiness probe timed out"
)
def
launch_openai_server
(
model
:
str
=
DEFAULT_MODEL
,
**
kw
):
"""Spawn the draft OpenAI-compatible server and wait until it's ready."""
port
=
_pick_free_port
()
cmd
=
[
sys
.
executable
,
"-m"
,
SERVER_MODULE
,
"--model-path"
,
model
,
"--host"
,
"127.0.0.1"
,
"--port"
,
str
(
port
),
*
map
(
str
,
kw
.
get
(
"args"
,
[])),
]
env
=
{
**
os
.
environ
,
**
kw
.
get
(
"env"
,
{})}
# Write logs to a temp file so the child never blocks on a full pipe.
log_file
=
tempfile
.
NamedTemporaryFile
(
"w+"
,
delete
=
False
)
proc
=
subprocess
.
Popen
(
cmd
,
env
=
env
,
stdout
=
log_file
,
stderr
=
subprocess
.
STDOUT
,
text
=
True
,
)
base
=
f
"http://127.0.0.1:
{
port
}
"
try
:
_wait_until_healthy
(
proc
,
base
,
STARTUP_TIMEOUT
)
except
Exception
as
e
:
proc
.
terminate
()
proc
.
wait
(
5
)
log_file
.
seek
(
0
)
print
(
"
\n
--- api_server log ---
\n
"
,
log_file
.
read
(),
file
=
sys
.
stderr
)
raise
e
return
proc
,
base
,
log_file
@
pytest
.
fixture
(
scope
=
"session"
)
def
openai_server
()
->
Generator
[
str
,
None
,
None
]:
"""PyTest fixture that provides the server's base URL and cleans up."""
proc
,
base
,
log_file
=
launch_openai_server
()
yield
base
kill_process_tree
(
proc
.
pid
)
log_file
.
close
()
test/srt/openai/test_protocol.py
View file @
72676cd6
...
...
@@ -67,29 +67,6 @@ from sglang.srt.entrypoints.openai.protocol import (
class
TestModelCard
(
unittest
.
TestCase
):
"""Test ModelCard protocol model"""
def
test_basic_model_card_creation
(
self
):
"""Test basic model card creation with required fields"""
card
=
ModelCard
(
id
=
"test-model"
)
self
.
assertEqual
(
card
.
id
,
"test-model"
)
self
.
assertEqual
(
card
.
object
,
"model"
)
self
.
assertEqual
(
card
.
owned_by
,
"sglang"
)
self
.
assertIsInstance
(
card
.
created
,
int
)
self
.
assertIsNone
(
card
.
root
)
self
.
assertIsNone
(
card
.
max_model_len
)
def
test_model_card_with_optional_fields
(
self
):
"""Test model card with optional fields"""
card
=
ModelCard
(
id
=
"test-model"
,
root
=
"/path/to/model"
,
max_model_len
=
2048
,
created
=
1234567890
,
)
self
.
assertEqual
(
card
.
id
,
"test-model"
)
self
.
assertEqual
(
card
.
root
,
"/path/to/model"
)
self
.
assertEqual
(
card
.
max_model_len
,
2048
)
self
.
assertEqual
(
card
.
created
,
1234567890
)
def
test_model_card_serialization
(
self
):
"""Test model card JSON serialization"""
card
=
ModelCard
(
id
=
"test-model"
,
max_model_len
=
4096
)
...
...
@@ -120,53 +97,6 @@ class TestModelList(unittest.TestCase):
self
.
assertEqual
(
model_list
.
data
[
1
].
id
,
"model-2"
)
class
TestErrorResponse
(
unittest
.
TestCase
):
"""Test ErrorResponse protocol model"""
def
test_basic_error_response
(
self
):
"""Test basic error response creation"""
error
=
ErrorResponse
(
message
=
"Invalid request"
,
type
=
"BadRequestError"
,
code
=
400
)
self
.
assertEqual
(
error
.
object
,
"error"
)
self
.
assertEqual
(
error
.
message
,
"Invalid request"
)
self
.
assertEqual
(
error
.
type
,
"BadRequestError"
)
self
.
assertEqual
(
error
.
code
,
400
)
self
.
assertIsNone
(
error
.
param
)
def
test_error_response_with_param
(
self
):
"""Test error response with parameter"""
error
=
ErrorResponse
(
message
=
"Invalid temperature"
,
type
=
"ValidationError"
,
code
=
422
,
param
=
"temperature"
,
)
self
.
assertEqual
(
error
.
param
,
"temperature"
)
class
TestUsageInfo
(
unittest
.
TestCase
):
"""Test UsageInfo protocol model"""
def
test_basic_usage_info
(
self
):
"""Test basic usage info creation"""
usage
=
UsageInfo
(
prompt_tokens
=
10
,
completion_tokens
=
20
,
total_tokens
=
30
)
self
.
assertEqual
(
usage
.
prompt_tokens
,
10
)
self
.
assertEqual
(
usage
.
completion_tokens
,
20
)
self
.
assertEqual
(
usage
.
total_tokens
,
30
)
self
.
assertIsNone
(
usage
.
prompt_tokens_details
)
def
test_usage_info_with_cache_details
(
self
):
"""Test usage info with cache details"""
usage
=
UsageInfo
(
prompt_tokens
=
10
,
completion_tokens
=
20
,
total_tokens
=
30
,
prompt_tokens_details
=
{
"cached_tokens"
:
5
},
)
self
.
assertEqual
(
usage
.
prompt_tokens_details
,
{
"cached_tokens"
:
5
})
class
TestCompletionRequest
(
unittest
.
TestCase
):
"""Test CompletionRequest protocol model"""
...
...
@@ -181,30 +111,6 @@ class TestCompletionRequest(unittest.TestCase):
self
.
assertFalse
(
request
.
stream
)
# default
self
.
assertFalse
(
request
.
echo
)
# default
def
test_completion_request_with_options
(
self
):
"""Test completion request with various options"""
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
[
"Hello"
,
"world"
],
max_tokens
=
100
,
temperature
=
0.7
,
top_p
=
0.9
,
n
=
2
,
stream
=
True
,
echo
=
True
,
stop
=
[
"."
,
"!"
],
logprobs
=
5
,
)
self
.
assertEqual
(
request
.
prompt
,
[
"Hello"
,
"world"
])
self
.
assertEqual
(
request
.
max_tokens
,
100
)
self
.
assertEqual
(
request
.
temperature
,
0.7
)
self
.
assertEqual
(
request
.
top_p
,
0.9
)
self
.
assertEqual
(
request
.
n
,
2
)
self
.
assertTrue
(
request
.
stream
)
self
.
assertTrue
(
request
.
echo
)
self
.
assertEqual
(
request
.
stop
,
[
"."
,
"!"
])
self
.
assertEqual
(
request
.
logprobs
,
5
)
def
test_completion_request_sglang_extensions
(
self
):
"""Test completion request with SGLang-specific extensions"""
request
=
CompletionRequest
(
...
...
@@ -233,26 +139,6 @@ class TestCompletionRequest(unittest.TestCase):
CompletionRequest
(
model
=
"test-model"
)
# missing prompt
class
TestCompletionResponse
(
unittest
.
TestCase
):
"""Test CompletionResponse protocol model"""
def
test_basic_completion_response
(
self
):
"""Test basic completion response"""
choice
=
CompletionResponseChoice
(
index
=
0
,
text
=
"Hello world!"
,
finish_reason
=
"stop"
)
usage
=
UsageInfo
(
prompt_tokens
=
2
,
completion_tokens
=
3
,
total_tokens
=
5
)
response
=
CompletionResponse
(
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
],
usage
=
usage
)
self
.
assertEqual
(
response
.
id
,
"test-id"
)
self
.
assertEqual
(
response
.
object
,
"text_completion"
)
self
.
assertEqual
(
response
.
model
,
"test-model"
)
self
.
assertEqual
(
len
(
response
.
choices
),
1
)
self
.
assertEqual
(
response
.
choices
[
0
].
text
,
"Hello world!"
)
self
.
assertEqual
(
response
.
usage
.
total_tokens
,
5
)
class
TestChatCompletionRequest
(
unittest
.
TestCase
):
"""Test ChatCompletionRequest protocol model"""
...
...
@@ -268,48 +154,6 @@ class TestChatCompletionRequest(unittest.TestCase):
self
.
assertFalse
(
request
.
stream
)
# default
self
.
assertEqual
(
request
.
tool_choice
,
"none"
)
# default when no tools
def
test_chat_completion_with_multimodal_content
(
self
):
"""Test chat completion with multimodal content"""
messages
=
[
{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"What's in this image?"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"data:image/jpeg;base64,/9j/4AAQ..."
},
},
],
}
]
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
)
self
.
assertEqual
(
len
(
request
.
messages
[
0
].
content
),
2
)
self
.
assertEqual
(
request
.
messages
[
0
].
content
[
0
].
type
,
"text"
)
self
.
assertEqual
(
request
.
messages
[
0
].
content
[
1
].
type
,
"image_url"
)
def
test_chat_completion_with_tools
(
self
):
"""Test chat completion with tools"""
messages
=
[{
"role"
:
"user"
,
"content"
:
"What's the weather?"
}]
tools
=
[
{
"type"
:
"function"
,
"function"
:
{
"name"
:
"get_weather"
,
"description"
:
"Get weather information"
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"location"
:
{
"type"
:
"string"
}},
},
},
}
]
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
,
tools
=
tools
)
self
.
assertEqual
(
len
(
request
.
tools
),
1
)
self
.
assertEqual
(
request
.
tools
[
0
].
function
.
name
,
"get_weather"
)
self
.
assertEqual
(
request
.
tool_choice
,
"auto"
)
# default when tools present
def
test_chat_completion_tool_choice_validation
(
self
):
"""Test tool choice validation logic"""
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
...
...
@@ -349,289 +193,6 @@ class TestChatCompletionRequest(unittest.TestCase):
self
.
assertEqual
(
request
.
chat_template_kwargs
,
{
"custom_param"
:
"value"
})
class
TestChatCompletionResponse
(
unittest
.
TestCase
):
"""Test ChatCompletionResponse protocol model"""
def
test_basic_chat_completion_response
(
self
):
"""Test basic chat completion response"""
message
=
ChatMessage
(
role
=
"assistant"
,
content
=
"Hello there!"
)
choice
=
ChatCompletionResponseChoice
(
index
=
0
,
message
=
message
,
finish_reason
=
"stop"
)
usage
=
UsageInfo
(
prompt_tokens
=
2
,
completion_tokens
=
3
,
total_tokens
=
5
)
response
=
ChatCompletionResponse
(
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
],
usage
=
usage
)
self
.
assertEqual
(
response
.
id
,
"test-id"
)
self
.
assertEqual
(
response
.
object
,
"chat.completion"
)
self
.
assertEqual
(
response
.
model
,
"test-model"
)
self
.
assertEqual
(
len
(
response
.
choices
),
1
)
self
.
assertEqual
(
response
.
choices
[
0
].
message
.
content
,
"Hello there!"
)
def
test_chat_completion_response_with_tool_calls
(
self
):
"""Test chat completion response with tool calls"""
tool_call
=
ToolCall
(
id
=
"call_123"
,
function
=
FunctionResponse
(
name
=
"get_weather"
,
arguments
=
'{"location": "San Francisco"}'
),
)
message
=
ChatMessage
(
role
=
"assistant"
,
content
=
None
,
tool_calls
=
[
tool_call
])
choice
=
ChatCompletionResponseChoice
(
index
=
0
,
message
=
message
,
finish_reason
=
"tool_calls"
)
usage
=
UsageInfo
(
prompt_tokens
=
10
,
completion_tokens
=
5
,
total_tokens
=
15
)
response
=
ChatCompletionResponse
(
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
],
usage
=
usage
)
self
.
assertEqual
(
response
.
choices
[
0
].
message
.
tool_calls
[
0
].
function
.
name
,
"get_weather"
)
self
.
assertEqual
(
response
.
choices
[
0
].
finish_reason
,
"tool_calls"
)
class
TestEmbeddingRequest
(
unittest
.
TestCase
):
"""Test EmbeddingRequest protocol model"""
def
test_basic_embedding_request
(
self
):
"""Test basic embedding request"""
request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
"Hello world"
)
self
.
assertEqual
(
request
.
model
,
"test-model"
)
self
.
assertEqual
(
request
.
input
,
"Hello world"
)
self
.
assertEqual
(
request
.
encoding_format
,
"float"
)
# default
self
.
assertIsNone
(
request
.
dimensions
)
# default
def
test_embedding_request_with_list_input
(
self
):
"""Test embedding request with list input"""
request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
[
"Hello"
,
"world"
],
dimensions
=
512
)
self
.
assertEqual
(
request
.
input
,
[
"Hello"
,
"world"
])
self
.
assertEqual
(
request
.
dimensions
,
512
)
def
test_multimodal_embedding_request
(
self
):
"""Test multimodal embedding request"""
multimodal_input
=
[
MultimodalEmbeddingInput
(
text
=
"Hello"
,
image
=
"base64_image_data"
),
MultimodalEmbeddingInput
(
text
=
"World"
,
image
=
None
),
]
request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
multimodal_input
)
self
.
assertEqual
(
len
(
request
.
input
),
2
)
self
.
assertEqual
(
request
.
input
[
0
].
text
,
"Hello"
)
self
.
assertEqual
(
request
.
input
[
0
].
image
,
"base64_image_data"
)
self
.
assertEqual
(
request
.
input
[
1
].
text
,
"World"
)
self
.
assertIsNone
(
request
.
input
[
1
].
image
)
class
TestEmbeddingResponse
(
unittest
.
TestCase
):
"""Test EmbeddingResponse protocol model"""
def
test_basic_embedding_response
(
self
):
"""Test basic embedding response"""
embedding_obj
=
EmbeddingObject
(
embedding
=
[
0.1
,
0.2
,
0.3
],
index
=
0
)
usage
=
UsageInfo
(
prompt_tokens
=
3
,
total_tokens
=
3
)
response
=
EmbeddingResponse
(
data
=
[
embedding_obj
],
model
=
"test-model"
,
usage
=
usage
)
self
.
assertEqual
(
response
.
object
,
"list"
)
self
.
assertEqual
(
len
(
response
.
data
),
1
)
self
.
assertEqual
(
response
.
data
[
0
].
embedding
,
[
0.1
,
0.2
,
0.3
])
self
.
assertEqual
(
response
.
data
[
0
].
index
,
0
)
self
.
assertEqual
(
response
.
usage
.
prompt_tokens
,
3
)
class
TestScoringRequest
(
unittest
.
TestCase
):
"""Test ScoringRequest protocol model"""
def
test_basic_scoring_request
(
self
):
"""Test basic scoring request"""
request
=
ScoringRequest
(
model
=
"test-model"
,
query
=
"Hello"
,
items
=
[
"World"
,
"Earth"
]
)
self
.
assertEqual
(
request
.
model
,
"test-model"
)
self
.
assertEqual
(
request
.
query
,
"Hello"
)
self
.
assertEqual
(
request
.
items
,
[
"World"
,
"Earth"
])
self
.
assertFalse
(
request
.
apply_softmax
)
# default
self
.
assertFalse
(
request
.
item_first
)
# default
def
test_scoring_request_with_token_ids
(
self
):
"""Test scoring request with token IDs"""
request
=
ScoringRequest
(
model
=
"test-model"
,
query
=
[
1
,
2
,
3
],
items
=
[[
4
,
5
],
[
6
,
7
]],
label_token_ids
=
[
8
,
9
],
apply_softmax
=
True
,
item_first
=
True
,
)
self
.
assertEqual
(
request
.
query
,
[
1
,
2
,
3
])
self
.
assertEqual
(
request
.
items
,
[[
4
,
5
],
[
6
,
7
]])
self
.
assertEqual
(
request
.
label_token_ids
,
[
8
,
9
])
self
.
assertTrue
(
request
.
apply_softmax
)
self
.
assertTrue
(
request
.
item_first
)
class
TestScoringResponse
(
unittest
.
TestCase
):
"""Test ScoringResponse protocol model"""
def
test_basic_scoring_response
(
self
):
"""Test basic scoring response"""
response
=
ScoringResponse
(
scores
=
[[
0.1
,
0.9
],
[
0.3
,
0.7
]],
model
=
"test-model"
)
self
.
assertEqual
(
response
.
object
,
"scoring"
)
self
.
assertEqual
(
response
.
scores
,
[[
0.1
,
0.9
],
[
0.3
,
0.7
]])
self
.
assertEqual
(
response
.
model
,
"test-model"
)
self
.
assertIsNone
(
response
.
usage
)
# default
class
TestFileOperations
(
unittest
.
TestCase
):
"""Test file operation protocol models"""
def
test_file_request
(
self
):
"""Test file request model"""
file_data
=
b
"test file content"
request
=
FileRequest
(
file
=
file_data
,
purpose
=
"batch"
)
self
.
assertEqual
(
request
.
file
,
file_data
)
self
.
assertEqual
(
request
.
purpose
,
"batch"
)
def
test_file_response
(
self
):
"""Test file response model"""
response
=
FileResponse
(
id
=
"file-123"
,
bytes
=
1024
,
created_at
=
1234567890
,
filename
=
"test.jsonl"
,
purpose
=
"batch"
,
)
self
.
assertEqual
(
response
.
id
,
"file-123"
)
self
.
assertEqual
(
response
.
object
,
"file"
)
self
.
assertEqual
(
response
.
bytes
,
1024
)
self
.
assertEqual
(
response
.
filename
,
"test.jsonl"
)
def
test_file_delete_response
(
self
):
"""Test file delete response model"""
response
=
FileDeleteResponse
(
id
=
"file-123"
,
deleted
=
True
)
self
.
assertEqual
(
response
.
id
,
"file-123"
)
self
.
assertEqual
(
response
.
object
,
"file"
)
self
.
assertTrue
(
response
.
deleted
)
class
TestBatchOperations
(
unittest
.
TestCase
):
"""Test batch operation protocol models"""
def
test_batch_request
(
self
):
"""Test batch request model"""
request
=
BatchRequest
(
input_file_id
=
"file-123"
,
endpoint
=
"/v1/chat/completions"
,
completion_window
=
"24h"
,
metadata
=
{
"custom"
:
"value"
},
)
self
.
assertEqual
(
request
.
input_file_id
,
"file-123"
)
self
.
assertEqual
(
request
.
endpoint
,
"/v1/chat/completions"
)
self
.
assertEqual
(
request
.
completion_window
,
"24h"
)
self
.
assertEqual
(
request
.
metadata
,
{
"custom"
:
"value"
})
def
test_batch_response
(
self
):
"""Test batch response model"""
response
=
BatchResponse
(
id
=
"batch-123"
,
endpoint
=
"/v1/chat/completions"
,
input_file_id
=
"file-123"
,
completion_window
=
"24h"
,
created_at
=
1234567890
,
)
self
.
assertEqual
(
response
.
id
,
"batch-123"
)
self
.
assertEqual
(
response
.
object
,
"batch"
)
self
.
assertEqual
(
response
.
status
,
"validating"
)
# default
self
.
assertEqual
(
response
.
endpoint
,
"/v1/chat/completions"
)
class
TestResponseFormats
(
unittest
.
TestCase
):
"""Test response format protocol models"""
def
test_basic_response_format
(
self
):
"""Test basic response format"""
format_obj
=
ResponseFormat
(
type
=
"json_object"
)
self
.
assertEqual
(
format_obj
.
type
,
"json_object"
)
self
.
assertIsNone
(
format_obj
.
json_schema
)
def
test_json_schema_response_format
(
self
):
"""Test JSON schema response format"""
schema
=
{
"type"
:
"object"
,
"properties"
:
{
"name"
:
{
"type"
:
"string"
}}}
json_schema
=
JsonSchemaResponseFormat
(
name
=
"person_schema"
,
description
=
"Person schema"
,
schema
=
schema
)
format_obj
=
ResponseFormat
(
type
=
"json_schema"
,
json_schema
=
json_schema
)
self
.
assertEqual
(
format_obj
.
type
,
"json_schema"
)
self
.
assertEqual
(
format_obj
.
json_schema
.
name
,
"person_schema"
)
self
.
assertEqual
(
format_obj
.
json_schema
.
schema_
,
schema
)
def
test_structural_tag_response_format
(
self
):
"""Test structural tag response format"""
structures
=
[
{
"begin"
:
"<thinking>"
,
"schema_"
:
{
"type"
:
"string"
},
"end"
:
"</thinking>"
,
}
]
format_obj
=
StructuralTagResponseFormat
(
type
=
"structural_tag"
,
structures
=
structures
,
triggers
=
[
"think"
]
)
self
.
assertEqual
(
format_obj
.
type
,
"structural_tag"
)
self
.
assertEqual
(
len
(
format_obj
.
structures
),
1
)
self
.
assertEqual
(
format_obj
.
triggers
,
[
"think"
])
class
TestLogProbs
(
unittest
.
TestCase
):
"""Test LogProbs protocol models"""
def
test_basic_logprobs
(
self
):
"""Test basic LogProbs model"""
logprobs
=
LogProbs
(
text_offset
=
[
0
,
5
,
11
],
token_logprobs
=
[
-
0.1
,
-
0.2
,
-
0.3
],
tokens
=
[
"Hello"
,
" "
,
"world"
],
top_logprobs
=
[{
"Hello"
:
-
0.1
},
{
" "
:
-
0.2
},
{
"world"
:
-
0.3
}],
)
self
.
assertEqual
(
len
(
logprobs
.
tokens
),
3
)
self
.
assertEqual
(
logprobs
.
tokens
,
[
"Hello"
,
" "
,
"world"
])
self
.
assertEqual
(
logprobs
.
token_logprobs
,
[
-
0.1
,
-
0.2
,
-
0.3
])
def
test_choice_logprobs
(
self
):
"""Test ChoiceLogprobs model"""
token_logprob
=
ChatCompletionTokenLogprob
(
token
=
"Hello"
,
bytes
=
[
72
,
101
,
108
,
108
,
111
],
logprob
=-
0.1
,
top_logprobs
=
[
TopLogprob
(
token
=
"Hello"
,
bytes
=
[
72
,
101
,
108
,
108
,
111
],
logprob
=-
0.1
)
],
)
choice_logprobs
=
ChoiceLogprobs
(
content
=
[
token_logprob
])
self
.
assertEqual
(
len
(
choice_logprobs
.
content
),
1
)
self
.
assertEqual
(
choice_logprobs
.
content
[
0
].
token
,
"Hello"
)
class
TestStreamingModels
(
unittest
.
TestCase
):
"""Test streaming response models"""
def
test_stream_options
(
self
):
"""Test StreamOptions model"""
options
=
StreamOptions
(
include_usage
=
True
)
self
.
assertTrue
(
options
.
include_usage
)
def
test_chat_completion_stream_response
(
self
):
"""Test ChatCompletionStreamResponse model"""
delta
=
DeltaMessage
(
role
=
"assistant"
,
content
=
"Hello"
)
choice
=
ChatCompletionResponseStreamChoice
(
index
=
0
,
delta
=
delta
)
response
=
ChatCompletionStreamResponse
(
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
]
)
self
.
assertEqual
(
response
.
object
,
"chat.completion.chunk"
)
self
.
assertEqual
(
response
.
choices
[
0
].
delta
.
content
,
"Hello"
)
class
TestModelSerialization
(
unittest
.
TestCase
):
"""Test model serialization with hidden states"""
...
...
@@ -680,11 +241,6 @@ class TestModelSerialization(unittest.TestCase):
class
TestValidationEdgeCases
(
unittest
.
TestCase
):
"""Test edge cases and validation scenarios"""
def
test_empty_messages_validation
(
self
):
"""Test validation with empty messages"""
with
self
.
assertRaises
(
ValidationError
):
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[])
def
test_invalid_tool_choice_type
(
self
):
"""Test invalid tool choice type"""
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
...
...
@@ -698,13 +254,6 @@ class TestValidationEdgeCases(unittest.TestCase):
with
self
.
assertRaises
(
ValidationError
):
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello"
,
max_tokens
=-
1
)
def
test_invalid_temperature_range
(
self
):
"""Test invalid temperature values"""
# Note: The current protocol doesn't enforce temperature range,
# but this test documents expected behavior
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello"
,
temperature
=
5.0
)
self
.
assertEqual
(
request
.
temperature
,
5.0
)
# Currently allowed
def
test_model_serialization_roundtrip
(
self
):
"""Test that models can be serialized and deserialized"""
original_request
=
ChatCompletionRequest
(
...
...
test/srt/openai/test_server.py
deleted
100644 → 0
View file @
02bf31ef
# sglang/test/srt/openai/test_server.py
import
requests
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
as
MODEL_ID
def
test_health
(
openai_server
:
str
):
r
=
requests
.
get
(
f
"
{
openai_server
}
/health"
)
assert
r
.
status_code
==
200
# FastAPI returns an empty body → r.text == ""
assert
r
.
text
==
""
def
test_models_endpoint
(
openai_server
:
str
):
r
=
requests
.
get
(
f
"
{
openai_server
}
/v1/models"
)
assert
r
.
status_code
==
200
,
r
.
text
payload
=
r
.
json
()
# Basic contract
assert
"data"
in
payload
and
isinstance
(
payload
[
"data"
],
list
)
and
payload
[
"data"
]
# Validate fields of the first model card
first
=
payload
[
"data"
][
0
]
for
key
in
(
"id"
,
"root"
,
"max_model_len"
):
assert
key
in
first
,
f
"missing
{
key
}
in
{
first
}
"
# max_model_len must be positive
assert
isinstance
(
first
[
"max_model_len"
],
int
)
and
first
[
"max_model_len"
]
>
0
# The server should report the same model id we launched it with
ids
=
{
m
[
"id"
]
for
m
in
payload
[
"data"
]}
assert
MODEL_ID
in
ids
def
test_get_model_info
(
openai_server
:
str
):
r
=
requests
.
get
(
f
"
{
openai_server
}
/get_model_info"
)
assert
r
.
status_code
==
200
,
r
.
text
info
=
r
.
json
()
expected_keys
=
{
"model_path"
,
"tokenizer_path"
,
"is_generation"
}
assert
expected_keys
.
issubset
(
info
.
keys
())
# model_path must end with the one we passed on the CLI
assert
info
[
"model_path"
].
endswith
(
MODEL_ID
)
# is_generation is documented as a boolean
assert
isinstance
(
info
[
"is_generation"
],
bool
)
def
test_unknown_route_returns_404
(
openai_server
:
str
):
r
=
requests
.
get
(
f
"
{
openai_server
}
/definitely-not-a-real-route"
)
assert
r
.
status_code
==
404
test/srt/openai/test_serving_chat.py
View file @
72676cd6
...
...
@@ -57,11 +57,21 @@ class _MockTokenizerManager:
self
.
create_abort_task
=
Mock
()
class
_MockTemplateManager
:
"""Minimal mock for TemplateManager."""
def
__init__
(
self
):
self
.
chat_template_name
:
Optional
[
str
]
=
"llama-3"
self
.
jinja_template_content_format
:
Optional
[
str
]
=
None
self
.
completion_template_name
:
Optional
[
str
]
=
None
class
ServingChatTestCase
(
unittest
.
TestCase
):
# ------------- common fixtures -------------
def
setUp
(
self
):
self
.
tm
=
_MockTokenizerManager
()
self
.
chat
=
OpenAIServingChat
(
self
.
tm
)
self
.
template_manager
=
_MockTemplateManager
()
self
.
chat
=
OpenAIServingChat
(
self
.
tm
,
self
.
template_manager
)
# frequently reused requests
self
.
basic_req
=
ChatCompletionRequest
(
...
...
@@ -109,96 +119,6 @@ class ServingChatTestCase(unittest.TestCase):
self
.
assertFalse
(
adapted
.
stream
)
self
.
assertEqual
(
processed
,
self
.
basic_req
)
# # ------------- tool-call branch -------------
# def test_tool_call_request_conversion(self):
# req = ChatCompletionRequest(
# model="x",
# messages=[{"role": "user", "content": "Weather?"}],
# tools=[
# {
# "type": "function",
# "function": {
# "name": "get_weather",
# "parameters": {"type": "object", "properties": {}},
# },
# }
# ],
# tool_choice="auto",
# )
# with patch.object(
# self.chat,
# "_process_messages",
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
# ):
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
# self.assertEqual(adapted.rid, "rid")
# def test_tool_choice_none(self):
# req = ChatCompletionRequest(
# model="x",
# messages=[{"role": "user", "content": "Hi"}],
# tools=[{"type": "function", "function": {"name": "noop"}}],
# tool_choice="none",
# )
# with patch.object(
# self.chat,
# "_process_messages",
# return_value=("Prompt", [1, 2, 3], None, None, [], ["</s>"], None),
# ):
# adapted, _ = self.chat._convert_to_internal_request(req, "rid")
# self.assertEqual(adapted.rid, "rid")
# ------------- multimodal branch -------------
def
test_multimodal_request_with_images
(
self
):
self
.
tm
.
model_config
.
is_multimodal
=
True
req
=
ChatCompletionRequest
(
model
=
"x"
,
messages
=
[
{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"What's in the image?"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"data:image/jpeg;base64,"
},
},
],
}
],
)
with
patch
.
object
(
self
.
chat
,
"_apply_jinja_template"
,
return_value
=
(
"prompt"
,
[
1
,
2
],
[
"img"
],
None
,
[],
[]),
),
patch
.
object
(
self
.
chat
,
"_apply_conversation_template"
,
return_value
=
(
"prompt"
,
[
"img"
],
None
,
[],
[]),
):
out
=
self
.
chat
.
_process_messages
(
req
,
True
)
_
,
_
,
image_data
,
*
_
=
out
self
.
assertEqual
(
image_data
,
[
"img"
])
# ------------- template handling -------------
def
test_jinja_template_processing
(
self
):
req
=
ChatCompletionRequest
(
model
=
"x"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
)
self
.
tm
.
chat_template_name
=
None
self
.
tm
.
tokenizer
.
chat_template
=
"<jinja>"
with
patch
.
object
(
self
.
chat
,
"_apply_jinja_template"
,
return_value
=
(
"processed"
,
[
1
],
None
,
None
,
[],
[
"</s>"
]),
),
patch
(
"builtins.hasattr"
,
return_value
=
True
):
prompt
,
prompt_ids
,
*
_
=
self
.
chat
.
_process_messages
(
req
,
False
)
self
.
assertEqual
(
prompt
,
"processed"
)
self
.
assertEqual
(
prompt_ids
,
[
1
])
# ------------- sampling-params -------------
def
test_sampling_param_build
(
self
):
req
=
ChatCompletionRequest
(
...
...
test/srt/openai/test_serving_completions.py
View file @
72676cd6
...
...
@@ -5,6 +5,7 @@ Run with:
"""
import
unittest
from
typing
import
Optional
from
unittest.mock
import
AsyncMock
,
Mock
,
patch
from
sglang.srt.entrypoints.openai.protocol
import
CompletionRequest
...
...
@@ -12,6 +13,17 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompl
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
class
_MockTemplateManager
:
"""Minimal mock for TemplateManager."""
def
__init__
(
self
):
self
.
chat_template_name
:
Optional
[
str
]
=
None
self
.
jinja_template_content_format
:
Optional
[
str
]
=
None
self
.
completion_template_name
:
Optional
[
str
]
=
(
None
# Set to None to avoid template processing
)
class
ServingCompletionTestCase
(
unittest
.
TestCase
):
"""Bundle all prompt/echo tests in one TestCase."""
...
...
@@ -31,7 +43,8 @@ class ServingCompletionTestCase(unittest.TestCase):
tm
.
generate_request
=
AsyncMock
()
tm
.
create_abort_task
=
Mock
()
self
.
sc
=
OpenAIServingCompletion
(
tm
)
self
.
template_manager
=
_MockTemplateManager
()
self
.
sc
=
OpenAIServingCompletion
(
tm
,
self
.
template_manager
)
# ---------- prompt-handling ----------
def
test_single_string_prompt
(
self
):
...
...
@@ -44,20 +57,6 @@ class ServingCompletionTestCase(unittest.TestCase):
internal
,
_
=
self
.
sc
.
_convert_to_internal_request
(
req
)
self
.
assertEqual
(
internal
.
input_ids
,
[
1
,
2
,
3
,
4
])
def
test_completion_template_handling
(
self
):
req
=
CompletionRequest
(
model
=
"x"
,
prompt
=
"def f():"
,
suffix
=
"return 1"
,
max_tokens
=
100
)
with
patch
(
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined"
,
return_value
=
True
,
),
patch
(
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request"
,
return_value
=
"processed_prompt"
,
):
internal
,
_
=
self
.
sc
.
_convert_to_internal_request
(
req
)
self
.
assertEqual
(
internal
.
text
,
"processed_prompt"
)
# ---------- echo-handling ----------
def
test_echo_with_string_prompt_streaming
(
self
):
req
=
CompletionRequest
(
model
=
"x"
,
prompt
=
"Hello"
,
max_tokens
=
1
,
echo
=
True
)
...
...
test/srt/openai/test_serving_embedding.py
View file @
72676cd6
...
...
@@ -5,25 +5,16 @@ These tests ensure that the embedding serving implementation maintains compatibi
with the original adapter.py functionality and follows OpenAI API specifications.
"""
import
asyncio
import
json
import
time
import
unittest
import
uuid
from
typing
import
Any
,
Dict
,
List
from
unittest.mock
import
AsyncMock
,
Mock
,
patch
from
unittest.mock
import
Mock
from
fastapi
import
Request
from
fastapi.responses
import
ORJSONResponse
from
pydantic_core
import
ValidationError
from
sglang.srt.entrypoints.openai.protocol
import
(
EmbeddingObject
,
EmbeddingRequest
,
EmbeddingResponse
,
ErrorResponse
,
MultimodalEmbeddingInput
,
UsageInfo
,
)
from
sglang.srt.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
sglang.srt.managers.io_struct
import
EmbeddingReqInput
...
...
@@ -58,11 +49,22 @@ class _MockTokenizerManager:
self
.
generate_request
=
Mock
(
return_value
=
mock_generate_embedding
())
# Mock TemplateManager for embedding tests
class
_MockTemplateManager
:
def
__init__
(
self
):
self
.
chat_template_name
=
None
# None for embeddings usually
self
.
jinja_template_content_format
=
None
self
.
completion_template_name
=
None
class
ServingEmbeddingTestCase
(
unittest
.
TestCase
):
def
setUp
(
self
):
"""Set up test fixtures."""
self
.
tokenizer_manager
=
_MockTokenizerManager
()
self
.
serving_embedding
=
OpenAIServingEmbedding
(
self
.
tokenizer_manager
)
self
.
template_manager
=
_MockTemplateManager
()
self
.
serving_embedding
=
OpenAIServingEmbedding
(
self
.
tokenizer_manager
,
self
.
template_manager
)
self
.
request
=
Mock
(
spec
=
Request
)
self
.
request
.
headers
=
{}
...
...
@@ -141,132 +143,6 @@ class ServingEmbeddingTestCase(unittest.TestCase):
self
.
assertIsNone
(
adapted_request
.
image_data
[
1
])
# self.assertEqual(adapted_request.rid, "test-id")
def
test_build_single_embedding_response
(
self
):
"""Test building response for single embedding."""
ret_data
=
[
{
"embedding"
:
[
0.1
,
0.2
,
0.3
,
0.4
,
0.5
],
"meta_info"
:
{
"prompt_tokens"
:
5
},
}
]
response
=
self
.
serving_embedding
.
_build_embedding_response
(
ret_data
)
self
.
assertIsInstance
(
response
,
EmbeddingResponse
)
self
.
assertEqual
(
response
.
model
,
"test-model"
)
self
.
assertEqual
(
len
(
response
.
data
),
1
)
self
.
assertEqual
(
response
.
data
[
0
].
embedding
,
[
0.1
,
0.2
,
0.3
,
0.4
,
0.5
])
self
.
assertEqual
(
response
.
data
[
0
].
index
,
0
)
self
.
assertEqual
(
response
.
data
[
0
].
object
,
"embedding"
)
self
.
assertEqual
(
response
.
usage
.
prompt_tokens
,
5
)
self
.
assertEqual
(
response
.
usage
.
total_tokens
,
5
)
self
.
assertEqual
(
response
.
usage
.
completion_tokens
,
0
)
def
test_build_multiple_embedding_response
(
self
):
"""Test building response for multiple embeddings."""
ret_data
=
[
{
"embedding"
:
[
0.1
,
0.2
,
0.3
],
"meta_info"
:
{
"prompt_tokens"
:
3
},
},
{
"embedding"
:
[
0.4
,
0.5
,
0.6
],
"meta_info"
:
{
"prompt_tokens"
:
4
},
},
]
response
=
self
.
serving_embedding
.
_build_embedding_response
(
ret_data
)
self
.
assertIsInstance
(
response
,
EmbeddingResponse
)
self
.
assertEqual
(
len
(
response
.
data
),
2
)
self
.
assertEqual
(
response
.
data
[
0
].
embedding
,
[
0.1
,
0.2
,
0.3
])
self
.
assertEqual
(
response
.
data
[
0
].
index
,
0
)
self
.
assertEqual
(
response
.
data
[
1
].
embedding
,
[
0.4
,
0.5
,
0.6
])
self
.
assertEqual
(
response
.
data
[
1
].
index
,
1
)
self
.
assertEqual
(
response
.
usage
.
prompt_tokens
,
7
)
# 3 + 4
self
.
assertEqual
(
response
.
usage
.
total_tokens
,
7
)
def
test_handle_request_success
(
self
):
"""Test successful embedding request handling."""
async
def
run_test
():
# Mock the generate_request to return expected data
async
def
mock_generate
():
yield
{
"embedding"
:
[
0.1
,
0.2
,
0.3
,
0.4
,
0.5
],
"meta_info"
:
{
"prompt_tokens"
:
5
},
}
self
.
serving_embedding
.
tokenizer_manager
.
generate_request
=
Mock
(
return_value
=
mock_generate
()
)
response
=
await
self
.
serving_embedding
.
handle_request
(
self
.
basic_req
,
self
.
request
)
self
.
assertIsInstance
(
response
,
EmbeddingResponse
)
self
.
assertEqual
(
len
(
response
.
data
),
1
)
self
.
assertEqual
(
response
.
data
[
0
].
embedding
,
[
0.1
,
0.2
,
0.3
,
0.4
,
0.5
])
asyncio
.
run
(
run_test
())
def
test_handle_request_validation_error
(
self
):
"""Test handling request with validation error."""
async
def
run_test
():
invalid_request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
""
)
response
=
await
self
.
serving_embedding
.
handle_request
(
invalid_request
,
self
.
request
)
self
.
assertIsInstance
(
response
,
ORJSONResponse
)
self
.
assertEqual
(
response
.
status_code
,
400
)
asyncio
.
run
(
run_test
())
def
test_handle_request_generation_error
(
self
):
"""Test handling request with generation error."""
async
def
run_test
():
# Mock generate_request to raise an error
async
def
mock_generate_error
():
raise
ValueError
(
"Generation failed"
)
yield
# This won't be reached but needed for async generator
self
.
serving_embedding
.
tokenizer_manager
.
generate_request
=
Mock
(
return_value
=
mock_generate_error
()
)
response
=
await
self
.
serving_embedding
.
handle_request
(
self
.
basic_req
,
self
.
request
)
self
.
assertIsInstance
(
response
,
ORJSONResponse
)
self
.
assertEqual
(
response
.
status_code
,
400
)
asyncio
.
run
(
run_test
())
def
test_handle_request_internal_error
(
self
):
"""Test handling request with internal server error."""
async
def
run_test
():
# Mock _convert_to_internal_request to raise an exception
with
patch
.
object
(
self
.
serving_embedding
,
"_convert_to_internal_request"
,
side_effect
=
Exception
(
"Internal error"
),
):
response
=
await
self
.
serving_embedding
.
handle_request
(
self
.
basic_req
,
self
.
request
)
self
.
assertIsInstance
(
response
,
ORJSONResponse
)
self
.
assertEqual
(
response
.
status_code
,
500
)
asyncio
.
run
(
run_test
())
if
__name__
==
"__main__"
:
unittest
.
main
(
verbosity
=
2
)
test/srt/run_suite.py
View file @
72676cd6
...
...
@@ -29,6 +29,10 @@ suites = {
TestFile
(
"models/test_reward_models.py"
,
132
),
TestFile
(
"models/test_vlm_models.py"
,
437
),
TestFile
(
"models/test_transformers_models.py"
,
320
),
TestFile
(
"openai/test_protocol.py"
,
10
),
TestFile
(
"openai/test_serving_chat.py"
,
10
),
TestFile
(
"openai/test_serving_completions.py"
,
10
),
TestFile
(
"openai/test_serving_embedding.py"
,
10
),
TestFile
(
"test_abort.py"
,
51
),
TestFile
(
"test_block_int8.py"
,
22
),
TestFile
(
"test_create_kvindices.py"
,
2
),
...
...
@@ -49,6 +53,7 @@ suites = {
TestFile
(
"test_hidden_states.py"
,
55
),
TestFile
(
"test_int8_kernel.py"
,
8
),
TestFile
(
"test_input_embeddings.py"
,
38
),
TestFile
(
"test_jinja_template_utils.py"
,
1
),
TestFile
(
"test_json_constrained.py"
,
98
),
TestFile
(
"test_large_max_new_tokens.py"
,
41
),
TestFile
(
"test_metrics.py"
,
32
),
...
...
@@ -59,14 +64,8 @@ suites = {
TestFile
(
"test_mla_fp8.py"
,
93
),
TestFile
(
"test_no_chunked_prefill.py"
,
108
),
TestFile
(
"test_no_overlap_scheduler.py"
,
234
),
TestFile
(
"test_openai_adapter.py"
,
1
),
TestFile
(
"test_openai_function_calling.py"
,
60
),
TestFile
(
"test_openai_server.py"
,
149
),
TestFile
(
"openai/test_server.py"
,
120
),
TestFile
(
"openai/test_protocol.py"
,
60
),
TestFile
(
"openai/test_serving_chat.py"
,
120
),
TestFile
(
"openai/test_serving_completions.py"
,
120
),
TestFile
(
"openai/test_serving_embedding.py"
,
120
),
TestFile
(
"test_openai_server_hidden_states.py"
,
240
),
TestFile
(
"test_penalty.py"
,
41
),
TestFile
(
"test_page_size.py"
,
60
),
...
...
test/srt/test_function_call_parser.py
View file @
72676cd6
...
...
@@ -3,6 +3,7 @@ import unittest
from
xgrammar
import
GrammarCompiler
,
TokenizerInfo
from
sglang.srt.entrypoints.openai.protocol
import
Function
,
Tool
from
sglang.srt.function_call.base_format_detector
import
BaseFormatDetector
from
sglang.srt.function_call.deepseekv3_detector
import
DeepSeekV3Detector
from
sglang.srt.function_call.llama32_detector
import
Llama32Detector
...
...
@@ -10,7 +11,6 @@ from sglang.srt.function_call.mistral_detector import MistralDetector
from
sglang.srt.function_call.pythonic_detector
import
PythonicDetector
from
sglang.srt.function_call.qwen25_detector
import
Qwen25Detector
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.srt.openai_api.protocol
import
Function
,
Tool
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
...
...
test/srt/test_
openai_adapter
.py
→
test/srt/test_
jinja_template_utils
.py
View file @
72676cd6
...
...
@@ -5,8 +5,8 @@ Unit tests for OpenAI adapter utils.
import
unittest
from
unittest.mock
import
patch
from
sglang.srt.
openai_api.
utils
import
(
detect_template_content_format
,
from
sglang.srt.
jinja_template_
utils
import
(
detect_
jinja_
template_content_format
,
process_content_for_template_format
,
)
from
sglang.test.test_utils
import
CustomTestCase
...
...
@@ -33,7 +33,7 @@ class TestTemplateContentFormatDetection(CustomTestCase):
{%- endfor %}
"""
result
=
detect_template_content_format
(
llama4_pattern
)
result
=
detect_
jinja_
template_content_format
(
llama4_pattern
)
self
.
assertEqual
(
result
,
"openai"
)
def
test_detect_deepseek_string_format
(
self
):
...
...
@@ -46,19 +46,19 @@ class TestTemplateContentFormatDetection(CustomTestCase):
{%- endfor %}
"""
result
=
detect_template_content_format
(
deepseek_pattern
)
result
=
detect_
jinja_
template_content_format
(
deepseek_pattern
)
self
.
assertEqual
(
result
,
"string"
)
def
test_detect_invalid_template
(
self
):
"""Test handling of invalid template (should default to 'string')."""
invalid_pattern
=
"{{{{ invalid jinja syntax }}}}"
result
=
detect_template_content_format
(
invalid_pattern
)
result
=
detect_
jinja_
template_content_format
(
invalid_pattern
)
self
.
assertEqual
(
result
,
"string"
)
def
test_detect_empty_template
(
self
):
"""Test handling of empty template (should default to 'string')."""
result
=
detect_template_content_format
(
""
)
result
=
detect_
jinja_
template_content_format
(
""
)
self
.
assertEqual
(
result
,
"string"
)
def
test_process_content_openai_format
(
self
):
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment