Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ffd1a26e
Unverified
Commit
ffd1a26e
authored
Jun 18, 2025
by
Jinn
Committed by
GitHub
Jun 18, 2025
Browse files
Add more refactored openai test & in CI (#7284)
parent
09ae5b20
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
576 additions
and
1059 deletions
+576
-1059
python/sglang/srt/entrypoints/openai/api_server.py
python/sglang/srt/entrypoints/openai/api_server.py
+2
-2
test/srt/openai/conftest.py
test/srt/openai/conftest.py
+4
-3
test/srt/openai/test_protocol.py
test/srt/openai/test_protocol.py
+179
-177
test/srt/openai/test_server.py
test/srt/openai/test_server.py
+41
-5
test/srt/openai/test_serving_chat.py
test/srt/openai/test_serving_chat.py
+156
-562
test/srt/openai/test_serving_completions.py
test/srt/openai/test_serving_completions.py
+68
-143
test/srt/openai/test_serving_embedding.py
test/srt/openai/test_serving_embedding.py
+121
-167
test/srt/run_suite.py
test/srt/run_suite.py
+5
-0
No files found.
python/sglang/srt/entrypoints/openai/api_server.py
View file @
ffd1a26e
...
...
@@ -36,7 +36,7 @@ from fastapi.middleware.cors import CORSMiddleware
from
fastapi.responses
import
Response
from
sglang.srt.disaggregation.utils
import
(
F
akeBootstrapHost
,
F
AKE_BOOTSTRAP_HOST
,
register_disaggregation_server
,
)
from
sglang.srt.entrypoints.engine
import
Engine
,
_launch_subprocesses
...
...
@@ -265,7 +265,7 @@ def _wait_and_warmup(
"max_new_tokens"
:
8
,
"ignore_eos"
:
True
,
},
"bootstrap_host"
:
[
F
akeBootstrapHost
]
*
server_args
.
dp_size
,
"bootstrap_host"
:
[
F
AKE_BOOTSTRAP_HOST
]
*
server_args
.
dp_size
,
# This is a hack to ensure fake transfer is enabled during prefill warmup
# ensure each dp rank has a unique bootstrap_room during prefill warmup
"bootstrap_room"
:
[
...
...
test/srt/openai/conftest.py
View file @
ffd1a26e
...
...
@@ -12,9 +12,10 @@ import pytest
import
requests
from
sglang.srt.utils
import
kill_process_tree
# reuse SGLang helper
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
SERVER_MODULE
=
"sglang.srt.entrypoints.openai.api_server"
DEFAULT_MODEL
=
"dummy-model"
DEFAULT_MODEL
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
STARTUP_TIMEOUT
=
float
(
os
.
getenv
(
"SGLANG_OPENAI_STARTUP_TIMEOUT"
,
120
))
...
...
@@ -39,7 +40,7 @@ def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> No
def
launch_openai_server
(
model
:
str
=
DEFAULT_MODEL
,
**
kw
):
"""Spawn the draft OpenAI-compatible server and wait until it
’
s ready."""
"""Spawn the draft OpenAI-compatible server and wait until it
'
s ready."""
port
=
_pick_free_port
()
cmd
=
[
sys
.
executable
,
...
...
@@ -79,7 +80,7 @@ def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
@
pytest
.
fixture
(
scope
=
"session"
)
def
openai_server
()
->
Generator
[
str
,
None
,
None
]:
"""PyTest fixture that provides the server
’
s base URL and cleans up."""
"""PyTest fixture that provides the server
'
s base URL and cleans up."""
proc
,
base
,
log_file
=
launch_openai_server
()
yield
base
kill_process_tree
(
proc
.
pid
)
...
...
test/srt/openai/test_protocol.py
View file @
ffd1a26e
...
...
@@ -15,9 +15,9 @@
import
json
import
time
import
unittest
from
typing
import
Dict
,
List
,
Optional
import
pytest
from
pydantic
import
ValidationError
from
sglang.srt.entrypoints.openai.protocol
import
(
...
...
@@ -64,18 +64,18 @@ from sglang.srt.entrypoints.openai.protocol import (
)
class
TestModelCard
:
class
TestModelCard
(
unittest
.
TestCase
)
:
"""Test ModelCard protocol model"""
def
test_basic_model_card_creation
(
self
):
"""Test basic model card creation with required fields"""
card
=
ModelCard
(
id
=
"test-model"
)
assert
card
.
id
==
"test-model"
assert
card
.
object
==
"model"
assert
card
.
owned_by
==
"sglang"
assert
isi
nstance
(
card
.
created
,
int
)
assert
card
.
root
is
None
assert
card
.
max_model_len
is
None
self
.
assert
Equal
(
card
.
id
,
"test-model"
)
self
.
assert
Equal
(
card
.
object
,
"model"
)
self
.
assert
Equal
(
card
.
owned_by
,
"sglang"
)
self
.
assert
IsI
nstance
(
card
.
created
,
int
)
self
.
assert
IsNone
(
card
.
root
)
self
.
assert
IsNone
(
card
.
max_model_len
)
def
test_model_card_with_optional_fields
(
self
):
"""Test model card with optional fields"""
...
...
@@ -85,28 +85,28 @@ class TestModelCard:
max_model_len
=
2048
,
created
=
1234567890
,
)
assert
card
.
id
==
"test-model"
assert
card
.
root
==
"/path/to/model"
assert
card
.
max_model_len
==
2048
assert
card
.
created
==
1234567890
self
.
assert
Equal
(
card
.
id
,
"test-model"
)
self
.
assert
Equal
(
card
.
root
,
"/path/to/model"
)
self
.
assert
Equal
(
card
.
max_model_len
,
2048
)
self
.
assert
Equal
(
card
.
created
,
1234567890
)
def
test_model_card_serialization
(
self
):
"""Test model card JSON serialization"""
card
=
ModelCard
(
id
=
"test-model"
,
max_model_len
=
4096
)
data
=
card
.
model_dump
()
assert
data
[
"id"
]
==
"test-model"
assert
data
[
"object"
]
==
"model"
assert
data
[
"max_model_len"
]
==
4096
self
.
assert
Equal
(
data
[
"id"
]
,
"test-model"
)
self
.
assert
Equal
(
data
[
"object"
]
,
"model"
)
self
.
assert
Equal
(
data
[
"max_model_len"
]
,
4096
)
class
TestModelList
:
class
TestModelList
(
unittest
.
TestCase
)
:
"""Test ModelList protocol model"""
def
test_empty_model_list
(
self
):
"""Test empty model list creation"""
model_list
=
ModelList
()
assert
model_list
.
object
==
"list"
assert
len
(
model_list
.
data
)
==
0
self
.
assert
Equal
(
model_list
.
object
,
"list"
)
self
.
assert
Equal
(
len
(
model_list
.
data
)
,
0
)
def
test_model_list_with_cards
(
self
):
"""Test model list with model cards"""
...
...
@@ -115,12 +115,12 @@ class TestModelList:
ModelCard
(
id
=
"model-2"
,
max_model_len
=
2048
),
]
model_list
=
ModelList
(
data
=
cards
)
assert
len
(
model_list
.
data
)
==
2
assert
model_list
.
data
[
0
].
id
==
"model-1"
assert
model_list
.
data
[
1
].
id
==
"model-2"
self
.
assert
Equal
(
len
(
model_list
.
data
)
,
2
)
self
.
assert
Equal
(
model_list
.
data
[
0
].
id
,
"model-1"
)
self
.
assert
Equal
(
model_list
.
data
[
1
].
id
,
"model-2"
)
class
TestErrorResponse
:
class
TestErrorResponse
(
unittest
.
TestCase
)
:
"""Test ErrorResponse protocol model"""
def
test_basic_error_response
(
self
):
...
...
@@ -128,11 +128,11 @@ class TestErrorResponse:
error
=
ErrorResponse
(
message
=
"Invalid request"
,
type
=
"BadRequestError"
,
code
=
400
)
assert
error
.
object
==
"error"
assert
error
.
message
==
"Invalid request"
assert
error
.
type
==
"BadRequestError"
assert
error
.
code
==
400
assert
error
.
param
is
None
self
.
assert
Equal
(
error
.
object
,
"error"
)
self
.
assert
Equal
(
error
.
message
,
"Invalid request"
)
self
.
assert
Equal
(
error
.
type
,
"BadRequestError"
)
self
.
assert
Equal
(
error
.
code
,
400
)
self
.
assert
IsNone
(
error
.
param
)
def
test_error_response_with_param
(
self
):
"""Test error response with parameter"""
...
...
@@ -142,19 +142,19 @@ class TestErrorResponse:
code
=
422
,
param
=
"temperature"
,
)
assert
error
.
param
==
"temperature"
self
.
assert
Equal
(
error
.
param
,
"temperature"
)
class
TestUsageInfo
:
class
TestUsageInfo
(
unittest
.
TestCase
)
:
"""Test UsageInfo protocol model"""
def
test_basic_usage_info
(
self
):
"""Test basic usage info creation"""
usage
=
UsageInfo
(
prompt_tokens
=
10
,
completion_tokens
=
20
,
total_tokens
=
30
)
assert
usage
.
prompt_tokens
==
10
assert
usage
.
completion_tokens
==
20
assert
usage
.
total_tokens
==
30
assert
usage
.
prompt_tokens_details
is
None
self
.
assert
Equal
(
usage
.
prompt_tokens
,
10
)
self
.
assert
Equal
(
usage
.
completion_tokens
,
20
)
self
.
assert
Equal
(
usage
.
total_tokens
,
30
)
self
.
assert
IsNone
(
usage
.
prompt_tokens_details
)
def
test_usage_info_with_cache_details
(
self
):
"""Test usage info with cache details"""
...
...
@@ -164,22 +164,22 @@ class TestUsageInfo:
total_tokens
=
30
,
prompt_tokens_details
=
{
"cached_tokens"
:
5
},
)
assert
usage
.
prompt_tokens_details
==
{
"cached_tokens"
:
5
}
self
.
assert
Equal
(
usage
.
prompt_tokens_details
,
{
"cached_tokens"
:
5
}
)
class
TestCompletionRequest
:
class
TestCompletionRequest
(
unittest
.
TestCase
)
:
"""Test CompletionRequest protocol model"""
def
test_basic_completion_request
(
self
):
"""Test basic completion request"""
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello world"
)
assert
request
.
model
==
"test-model"
assert
request
.
prompt
==
"Hello world"
assert
request
.
max_tokens
==
16
# default
assert
request
.
temperature
==
1.0
# default
assert
request
.
n
==
1
# default
assert
not
request
.
stream
# default
assert
not
request
.
echo
# default
self
.
assert
Equal
(
request
.
model
,
"test-model"
)
self
.
assert
Equal
(
request
.
prompt
,
"Hello world"
)
self
.
assert
Equal
(
request
.
max_tokens
,
16
)
# default
self
.
assert
Equal
(
request
.
temperature
,
1.0
)
# default
self
.
assert
Equal
(
request
.
n
,
1
)
# default
self
.
assert
False
(
request
.
stream
)
# default
self
.
assert
False
(
request
.
echo
)
# default
def
test_completion_request_with_options
(
self
):
"""Test completion request with various options"""
...
...
@@ -195,15 +195,15 @@ class TestCompletionRequest:
stop
=
[
"."
,
"!"
],
logprobs
=
5
,
)
assert
request
.
prompt
==
[
"Hello"
,
"world"
]
assert
request
.
max_tokens
==
100
assert
request
.
temperature
==
0.7
assert
request
.
top_p
==
0.9
assert
request
.
n
==
2
assert
request
.
stream
assert
request
.
echo
assert
request
.
stop
==
[
"."
,
"!"
]
assert
request
.
logprobs
==
5
self
.
assert
Equal
(
request
.
prompt
,
[
"Hello"
,
"world"
]
)
self
.
assert
Equal
(
request
.
max_tokens
,
100
)
self
.
assert
Equal
(
request
.
temperature
,
0.7
)
self
.
assert
Equal
(
request
.
top_p
,
0.9
)
self
.
assert
Equal
(
request
.
n
,
2
)
self
.
assert
True
(
request
.
stream
)
self
.
assert
True
(
request
.
echo
)
self
.
assert
Equal
(
request
.
stop
,
[
"."
,
"!"
]
)
self
.
assert
Equal
(
request
.
logprobs
,
5
)
def
test_completion_request_sglang_extensions
(
self
):
"""Test completion request with SGLang-specific extensions"""
...
...
@@ -217,23 +217,23 @@ class TestCompletionRequest:
json_schema
=
'{"type": "object"}'
,
lora_path
=
"/path/to/lora"
,
)
assert
request
.
top_k
==
50
assert
request
.
min_p
==
0.1
assert
request
.
repetition_penalty
==
1.1
assert
request
.
regex
==
r
"\d+"
assert
request
.
json_schema
==
'{"type": "object"}'
assert
request
.
lora_path
==
"/path/to/lora"
self
.
assert
Equal
(
request
.
top_k
,
50
)
self
.
assert
Equal
(
request
.
min_p
,
0.1
)
self
.
assert
Equal
(
request
.
repetition_penalty
,
1.1
)
self
.
assert
Equal
(
request
.
regex
,
r
"\d+"
)
self
.
assert
Equal
(
request
.
json_schema
,
'{"type": "object"}'
)
self
.
assert
Equal
(
request
.
lora_path
,
"/path/to/lora"
)
def
test_completion_request_validation_errors
(
self
):
"""Test completion request validation errors"""
with
pytest
.
r
aises
(
ValidationError
):
with
self
.
assertR
aises
(
ValidationError
):
CompletionRequest
()
# missing required fields
with
pytest
.
r
aises
(
ValidationError
):
with
self
.
assertR
aises
(
ValidationError
):
CompletionRequest
(
model
=
"test-model"
)
# missing prompt
class
TestCompletionResponse
:
class
TestCompletionResponse
(
unittest
.
TestCase
)
:
"""Test CompletionResponse protocol model"""
def
test_basic_completion_response
(
self
):
...
...
@@ -245,28 +245,28 @@ class TestCompletionResponse:
response
=
CompletionResponse
(
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
],
usage
=
usage
)
assert
response
.
id
==
"test-id"
assert
response
.
object
==
"text_completion"
assert
response
.
model
==
"test-model"
assert
len
(
response
.
choices
)
==
1
assert
response
.
choices
[
0
].
text
==
"Hello world!"
assert
response
.
usage
.
total_tokens
==
5
self
.
assert
Equal
(
response
.
id
,
"test-id"
)
self
.
assert
Equal
(
response
.
object
,
"text_completion"
)
self
.
assert
Equal
(
response
.
model
,
"test-model"
)
self
.
assert
Equal
(
len
(
response
.
choices
)
,
1
)
self
.
assert
Equal
(
response
.
choices
[
0
].
text
,
"Hello world!"
)
self
.
assert
Equal
(
response
.
usage
.
total_tokens
,
5
)
class
TestChatCompletionRequest
:
class
TestChatCompletionRequest
(
unittest
.
TestCase
)
:
"""Test ChatCompletionRequest protocol model"""
def
test_basic_chat_completion_request
(
self
):
"""Test basic chat completion request"""
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
)
assert
request
.
model
==
"test-model"
assert
len
(
request
.
messages
)
==
1
assert
request
.
messages
[
0
].
role
==
"user"
assert
request
.
messages
[
0
].
content
==
"Hello"
assert
request
.
temperature
==
0.7
# default
assert
not
request
.
stream
# default
assert
request
.
tool_choice
==
"none"
# default when no tools
self
.
assert
Equal
(
request
.
model
,
"test-model"
)
self
.
assert
Equal
(
len
(
request
.
messages
)
,
1
)
self
.
assert
Equal
(
request
.
messages
[
0
].
role
,
"user"
)
self
.
assert
Equal
(
request
.
messages
[
0
].
content
,
"Hello"
)
self
.
assert
Equal
(
request
.
temperature
,
0.7
)
# default
self
.
assert
False
(
request
.
stream
)
# default
self
.
assert
Equal
(
request
.
tool_choice
,
"none"
)
# default when no tools
def
test_chat_completion_with_multimodal_content
(
self
):
"""Test chat completion with multimodal content"""
...
...
@@ -283,9 +283,9 @@ class TestChatCompletionRequest:
}
]
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
)
assert
len
(
request
.
messages
[
0
].
content
)
==
2
assert
request
.
messages
[
0
].
content
[
0
].
type
==
"text"
assert
request
.
messages
[
0
].
content
[
1
].
type
==
"image_url"
self
.
assert
Equal
(
len
(
request
.
messages
[
0
].
content
)
,
2
)
self
.
assert
Equal
(
request
.
messages
[
0
].
content
[
0
].
type
,
"text"
)
self
.
assert
Equal
(
request
.
messages
[
0
].
content
[
1
].
type
,
"image_url"
)
def
test_chat_completion_with_tools
(
self
):
"""Test chat completion with tools"""
...
...
@@ -306,9 +306,9 @@ class TestChatCompletionRequest:
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
,
tools
=
tools
)
assert
len
(
request
.
tools
)
==
1
assert
request
.
tools
[
0
].
function
.
name
==
"get_weather"
assert
request
.
tool_choice
==
"auto"
# default when tools present
self
.
assert
Equal
(
len
(
request
.
tools
)
,
1
)
self
.
assert
Equal
(
request
.
tools
[
0
].
function
.
name
,
"get_weather"
)
self
.
assert
Equal
(
request
.
tool_choice
,
"auto"
)
# default when tools present
def
test_chat_completion_tool_choice_validation
(
self
):
"""Test tool choice validation logic"""
...
...
@@ -316,7 +316,7 @@ class TestChatCompletionRequest:
# No tools, tool_choice should default to "none"
request1
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
)
assert
request1
.
tool_choice
==
"none"
self
.
assert
Equal
(
request1
.
tool_choice
,
"none"
)
# With tools, tool_choice should default to "auto"
tools
=
[
...
...
@@ -328,7 +328,7 @@ class TestChatCompletionRequest:
request2
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
,
tools
=
tools
)
assert
request2
.
tool_choice
==
"auto"
self
.
assert
Equal
(
request2
.
tool_choice
,
"auto"
)
def
test_chat_completion_sglang_extensions
(
self
):
"""Test chat completion with SGLang extensions"""
...
...
@@ -342,14 +342,14 @@ class TestChatCompletionRequest:
stream_reasoning
=
False
,
chat_template_kwargs
=
{
"custom_param"
:
"value"
},
)
assert
request
.
top_k
==
40
assert
request
.
min_p
==
0.05
assert
not
request
.
separate_reasoning
assert
not
request
.
stream_reasoning
assert
request
.
chat_template_kwargs
==
{
"custom_param"
:
"value"
}
self
.
assert
Equal
(
request
.
top_k
,
40
)
self
.
assert
Equal
(
request
.
min_p
,
0.05
)
self
.
assert
False
(
request
.
separate_reasoning
)
self
.
assert
False
(
request
.
stream_reasoning
)
self
.
assert
Equal
(
request
.
chat_template_kwargs
,
{
"custom_param"
:
"value"
}
)
class
TestChatCompletionResponse
:
class
TestChatCompletionResponse
(
unittest
.
TestCase
)
:
"""Test ChatCompletionResponse protocol model"""
def
test_basic_chat_completion_response
(
self
):
...
...
@@ -362,11 +362,11 @@ class TestChatCompletionResponse:
response
=
ChatCompletionResponse
(
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
],
usage
=
usage
)
assert
response
.
id
==
"test-id"
assert
response
.
object
==
"chat.completion"
assert
response
.
model
==
"test-model"
assert
len
(
response
.
choices
)
==
1
assert
response
.
choices
[
0
].
message
.
content
==
"Hello there!"
self
.
assert
Equal
(
response
.
id
,
"test-id"
)
self
.
assert
Equal
(
response
.
object
,
"chat.completion"
)
self
.
assert
Equal
(
response
.
model
,
"test-model"
)
self
.
assert
Equal
(
len
(
response
.
choices
)
,
1
)
self
.
assert
Equal
(
response
.
choices
[
0
].
message
.
content
,
"Hello there!"
)
def
test_chat_completion_response_with_tool_calls
(
self
):
"""Test chat completion response with tool calls"""
...
...
@@ -384,28 +384,30 @@ class TestChatCompletionResponse:
response
=
ChatCompletionResponse
(
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
],
usage
=
usage
)
assert
response
.
choices
[
0
].
message
.
tool_calls
[
0
].
function
.
name
==
"get_weather"
assert
response
.
choices
[
0
].
finish_reason
==
"tool_calls"
self
.
assertEqual
(
response
.
choices
[
0
].
message
.
tool_calls
[
0
].
function
.
name
,
"get_weather"
)
self
.
assertEqual
(
response
.
choices
[
0
].
finish_reason
,
"tool_calls"
)
class
TestEmbeddingRequest
:
class
TestEmbeddingRequest
(
unittest
.
TestCase
)
:
"""Test EmbeddingRequest protocol model"""
def
test_basic_embedding_request
(
self
):
"""Test basic embedding request"""
request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
"Hello world"
)
assert
request
.
model
==
"test-model"
assert
request
.
input
==
"Hello world"
assert
request
.
encoding_format
==
"float"
# default
assert
request
.
dimensions
is
None
# default
self
.
assert
Equal
(
request
.
model
,
"test-model"
)
self
.
assert
Equal
(
request
.
input
,
"Hello world"
)
self
.
assert
Equal
(
request
.
encoding_format
,
"float"
)
# default
self
.
assert
IsNone
(
request
.
dimensions
)
# default
def
test_embedding_request_with_list_input
(
self
):
"""Test embedding request with list input"""
request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
[
"Hello"
,
"world"
],
dimensions
=
512
)
assert
request
.
input
==
[
"Hello"
,
"world"
]
assert
request
.
dimensions
==
512
self
.
assert
Equal
(
request
.
input
,
[
"Hello"
,
"world"
]
)
self
.
assert
Equal
(
request
.
dimensions
,
512
)
def
test_multimodal_embedding_request
(
self
):
"""Test multimodal embedding request"""
...
...
@@ -414,14 +416,14 @@ class TestEmbeddingRequest:
MultimodalEmbeddingInput
(
text
=
"World"
,
image
=
None
),
]
request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
multimodal_input
)
assert
len
(
request
.
input
)
==
2
assert
request
.
input
[
0
].
text
==
"Hello"
assert
request
.
input
[
0
].
image
==
"base64_image_data"
assert
request
.
input
[
1
].
text
==
"World"
assert
request
.
input
[
1
].
image
is
None
self
.
assert
Equal
(
len
(
request
.
input
)
,
2
)
self
.
assert
Equal
(
request
.
input
[
0
].
text
,
"Hello"
)
self
.
assert
Equal
(
request
.
input
[
0
].
image
,
"base64_image_data"
)
self
.
assert
Equal
(
request
.
input
[
1
].
text
,
"World"
)
self
.
assert
IsNone
(
request
.
input
[
1
].
image
)
class
TestEmbeddingResponse
:
class
TestEmbeddingResponse
(
unittest
.
TestCase
)
:
"""Test EmbeddingResponse protocol model"""
def
test_basic_embedding_response
(
self
):
...
...
@@ -431,14 +433,14 @@ class TestEmbeddingResponse:
response
=
EmbeddingResponse
(
data
=
[
embedding_obj
],
model
=
"test-model"
,
usage
=
usage
)
assert
response
.
object
==
"list"
assert
len
(
response
.
data
)
==
1
assert
response
.
data
[
0
].
embedding
==
[
0.1
,
0.2
,
0.3
]
assert
response
.
data
[
0
].
index
==
0
assert
response
.
usage
.
prompt_tokens
==
3
self
.
assert
Equal
(
response
.
object
,
"list"
)
self
.
assert
Equal
(
len
(
response
.
data
)
,
1
)
self
.
assert
Equal
(
response
.
data
[
0
].
embedding
,
[
0.1
,
0.2
,
0.3
]
)
self
.
assert
Equal
(
response
.
data
[
0
].
index
,
0
)
self
.
assert
Equal
(
response
.
usage
.
prompt_tokens
,
3
)
class
TestScoringRequest
:
class
TestScoringRequest
(
unittest
.
TestCase
)
:
"""Test ScoringRequest protocol model"""
def
test_basic_scoring_request
(
self
):
...
...
@@ -446,11 +448,11 @@ class TestScoringRequest:
request
=
ScoringRequest
(
model
=
"test-model"
,
query
=
"Hello"
,
items
=
[
"World"
,
"Earth"
]
)
assert
request
.
model
==
"test-model"
assert
request
.
query
==
"Hello"
assert
request
.
items
==
[
"World"
,
"Earth"
]
assert
not
request
.
apply_softmax
# default
assert
not
request
.
item_first
# default
self
.
assert
Equal
(
request
.
model
,
"test-model"
)
self
.
assert
Equal
(
request
.
query
,
"Hello"
)
self
.
assert
Equal
(
request
.
items
,
[
"World"
,
"Earth"
]
)
self
.
assert
False
(
request
.
apply_softmax
)
# default
self
.
assert
False
(
request
.
item_first
)
# default
def
test_scoring_request_with_token_ids
(
self
):
"""Test scoring request with token IDs"""
...
...
@@ -462,34 +464,34 @@ class TestScoringRequest:
apply_softmax
=
True
,
item_first
=
True
,
)
assert
request
.
query
==
[
1
,
2
,
3
]
assert
request
.
items
==
[[
4
,
5
],
[
6
,
7
]]
assert
request
.
label_token_ids
==
[
8
,
9
]
assert
request
.
apply_softmax
assert
request
.
item_first
self
.
assert
Equal
(
request
.
query
,
[
1
,
2
,
3
]
)
self
.
assert
Equal
(
request
.
items
,
[[
4
,
5
],
[
6
,
7
]]
)
self
.
assert
Equal
(
request
.
label_token_ids
,
[
8
,
9
]
)
self
.
assert
True
(
request
.
apply_softmax
)
self
.
assert
True
(
request
.
item_first
)
class
TestScoringResponse
:
class
TestScoringResponse
(
unittest
.
TestCase
)
:
"""Test ScoringResponse protocol model"""
def
test_basic_scoring_response
(
self
):
"""Test basic scoring response"""
response
=
ScoringResponse
(
scores
=
[[
0.1
,
0.9
],
[
0.3
,
0.7
]],
model
=
"test-model"
)
assert
response
.
object
==
"scoring"
assert
response
.
scores
==
[[
0.1
,
0.9
],
[
0.3
,
0.7
]]
assert
response
.
model
==
"test-model"
assert
response
.
usage
is
None
# default
self
.
assert
Equal
(
response
.
object
,
"scoring"
)
self
.
assert
Equal
(
response
.
scores
,
[[
0.1
,
0.9
],
[
0.3
,
0.7
]]
)
self
.
assert
Equal
(
response
.
model
,
"test-model"
)
self
.
assert
IsNone
(
response
.
usage
)
# default
class
TestFileOperations
:
class
TestFileOperations
(
unittest
.
TestCase
)
:
"""Test file operation protocol models"""
def
test_file_request
(
self
):
"""Test file request model"""
file_data
=
b
"test file content"
request
=
FileRequest
(
file
=
file_data
,
purpose
=
"batch"
)
assert
request
.
file
==
file_data
assert
request
.
purpose
==
"batch"
self
.
assert
Equal
(
request
.
file
,
file_data
)
self
.
assert
Equal
(
request
.
purpose
,
"batch"
)
def
test_file_response
(
self
):
"""Test file response model"""
...
...
@@ -500,20 +502,20 @@ class TestFileOperations:
filename
=
"test.jsonl"
,
purpose
=
"batch"
,
)
assert
response
.
id
==
"file-123"
assert
response
.
object
==
"file"
assert
response
.
bytes
==
1024
assert
response
.
filename
==
"test.jsonl"
self
.
assert
Equal
(
response
.
id
,
"file-123"
)
self
.
assert
Equal
(
response
.
object
,
"file"
)
self
.
assert
Equal
(
response
.
bytes
,
1024
)
self
.
assert
Equal
(
response
.
filename
,
"test.jsonl"
)
def
test_file_delete_response
(
self
):
"""Test file delete response model"""
response
=
FileDeleteResponse
(
id
=
"file-123"
,
deleted
=
True
)
assert
response
.
id
==
"file-123"
assert
response
.
object
==
"file"
assert
response
.
deleted
self
.
assert
Equal
(
response
.
id
,
"file-123"
)
self
.
assert
Equal
(
response
.
object
,
"file"
)
self
.
assert
True
(
response
.
deleted
)
class
TestBatchOperations
:
class
TestBatchOperations
(
unittest
.
TestCase
)
:
"""Test batch operation protocol models"""
def
test_batch_request
(
self
):
...
...
@@ -524,10 +526,10 @@ class TestBatchOperations:
completion_window
=
"24h"
,
metadata
=
{
"custom"
:
"value"
},
)
assert
request
.
input_file_id
==
"file-123"
assert
request
.
endpoint
==
"/v1/chat/completions"
assert
request
.
completion_window
==
"24h"
assert
request
.
metadata
==
{
"custom"
:
"value"
}
self
.
assert
Equal
(
request
.
input_file_id
,
"file-123"
)
self
.
assert
Equal
(
request
.
endpoint
,
"/v1/chat/completions"
)
self
.
assert
Equal
(
request
.
completion_window
,
"24h"
)
self
.
assert
Equal
(
request
.
metadata
,
{
"custom"
:
"value"
}
)
def
test_batch_response
(
self
):
"""Test batch response model"""
...
...
@@ -538,20 +540,20 @@ class TestBatchOperations:
completion_window
=
"24h"
,
created_at
=
1234567890
,
)
assert
response
.
id
==
"batch-123"
assert
response
.
object
==
"batch"
assert
response
.
status
==
"validating"
# default
assert
response
.
endpoint
==
"/v1/chat/completions"
self
.
assert
Equal
(
response
.
id
,
"batch-123"
)
self
.
assert
Equal
(
response
.
object
,
"batch"
)
self
.
assert
Equal
(
response
.
status
,
"validating"
)
# default
self
.
assert
Equal
(
response
.
endpoint
,
"/v1/chat/completions"
)
class
TestResponseFormats
:
class
TestResponseFormats
(
unittest
.
TestCase
)
:
"""Test response format protocol models"""
def
test_basic_response_format
(
self
):
"""Test basic response format"""
format_obj
=
ResponseFormat
(
type
=
"json_object"
)
assert
format_obj
.
type
==
"json_object"
assert
format_obj
.
json_schema
is
None
self
.
assert
Equal
(
format_obj
.
type
,
"json_object"
)
self
.
assert
IsNone
(
format_obj
.
json_schema
)
def
test_json_schema_response_format
(
self
):
"""Test JSON schema response format"""
...
...
@@ -560,9 +562,9 @@ class TestResponseFormats:
name
=
"person_schema"
,
description
=
"Person schema"
,
schema
=
schema
)
format_obj
=
ResponseFormat
(
type
=
"json_schema"
,
json_schema
=
json_schema
)
assert
format_obj
.
type
==
"json_schema"
assert
format_obj
.
json_schema
.
name
==
"person_schema"
assert
format_obj
.
json_schema
.
schema_
==
schema
self
.
assert
Equal
(
format_obj
.
type
,
"json_schema"
)
self
.
assert
Equal
(
format_obj
.
json_schema
.
name
,
"person_schema"
)
self
.
assert
Equal
(
format_obj
.
json_schema
.
schema_
,
schema
)
def
test_structural_tag_response_format
(
self
):
"""Test structural tag response format"""
...
...
@@ -576,12 +578,12 @@ class TestResponseFormats:
format_obj
=
StructuralTagResponseFormat
(
type
=
"structural_tag"
,
structures
=
structures
,
triggers
=
[
"think"
]
)
assert
format_obj
.
type
==
"structural_tag"
assert
len
(
format_obj
.
structures
)
==
1
assert
format_obj
.
triggers
==
[
"think"
]
self
.
assert
Equal
(
format_obj
.
type
,
"structural_tag"
)
self
.
assert
Equal
(
len
(
format_obj
.
structures
)
,
1
)
self
.
assert
Equal
(
format_obj
.
triggers
,
[
"think"
]
)
class
TestLogProbs
:
class
TestLogProbs
(
unittest
.
TestCase
)
:
"""Test LogProbs protocol models"""
def
test_basic_logprobs
(
self
):
...
...
@@ -592,9 +594,9 @@ class TestLogProbs:
tokens
=
[
"Hello"
,
" "
,
"world"
],
top_logprobs
=
[{
"Hello"
:
-
0.1
},
{
" "
:
-
0.2
},
{
"world"
:
-
0.3
}],
)
assert
len
(
logprobs
.
tokens
)
==
3
assert
logprobs
.
tokens
==
[
"Hello"
,
" "
,
"world"
]
assert
logprobs
.
token_logprobs
==
[
-
0.1
,
-
0.2
,
-
0.3
]
self
.
assert
Equal
(
len
(
logprobs
.
tokens
)
,
3
)
self
.
assert
Equal
(
logprobs
.
tokens
,
[
"Hello"
,
" "
,
"world"
]
)
self
.
assert
Equal
(
logprobs
.
token_logprobs
,
[
-
0.1
,
-
0.2
,
-
0.3
]
)
def
test_choice_logprobs
(
self
):
"""Test ChoiceLogprobs model"""
...
...
@@ -607,17 +609,17 @@ class TestLogProbs:
],
)
choice_logprobs
=
ChoiceLogprobs
(
content
=
[
token_logprob
])
assert
len
(
choice_logprobs
.
content
)
==
1
assert
choice_logprobs
.
content
[
0
].
token
==
"Hello"
self
.
assert
Equal
(
len
(
choice_logprobs
.
content
)
,
1
)
self
.
assert
Equal
(
choice_logprobs
.
content
[
0
].
token
,
"Hello"
)
class
TestStreamingModels
:
class
TestStreamingModels
(
unittest
.
TestCase
)
:
"""Test streaming response models"""
def
test_stream_options
(
self
):
"""Test StreamOptions model"""
options
=
StreamOptions
(
include_usage
=
True
)
assert
options
.
include_usage
self
.
assert
True
(
options
.
include_usage
)
def
test_chat_completion_stream_response
(
self
):
"""Test ChatCompletionStreamResponse model"""
...
...
@@ -626,29 +628,29 @@ class TestStreamingModels:
response
=
ChatCompletionStreamResponse
(
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
]
)
assert
response
.
object
==
"chat.completion.chunk"
assert
response
.
choices
[
0
].
delta
.
content
==
"Hello"
self
.
assert
Equal
(
response
.
object
,
"chat.completion.chunk"
)
self
.
assert
Equal
(
response
.
choices
[
0
].
delta
.
content
,
"Hello"
)
class
TestValidationEdgeCases
:
class
TestValidationEdgeCases
(
unittest
.
TestCase
)
:
"""Test edge cases and validation scenarios"""
def
test_empty_messages_validation
(
self
):
"""Test validation with empty messages"""
with
pytest
.
r
aises
(
ValidationError
):
with
self
.
assertR
aises
(
ValidationError
):
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[])
def
test_invalid_tool_choice_type
(
self
):
"""Test invalid tool choice type"""
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
with
pytest
.
r
aises
(
ValidationError
):
with
self
.
assertR
aises
(
ValidationError
):
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
,
tool_choice
=
123
)
def
test_negative_token_limits
(
self
):
"""Test negative token limits"""
with
pytest
.
r
aises
(
ValidationError
):
with
self
.
assertR
aises
(
ValidationError
):
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello"
,
max_tokens
=-
1
)
def
test_invalid_temperature_range
(
self
):
...
...
@@ -656,7 +658,7 @@ class TestValidationEdgeCases:
# Note: The current protocol doesn't enforce temperature range,
# but this test documents expected behavior
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello"
,
temperature
=
5.0
)
assert
request
.
temperature
==
5.0
# Currently allowed
self
.
assert
Equal
(
request
.
temperature
,
5.0
)
# Currently allowed
def
test_model_serialization_roundtrip
(
self
):
"""Test that models can be serialized and deserialized"""
...
...
@@ -673,11 +675,11 @@ class TestValidationEdgeCases:
# Deserialize back
restored_request
=
ChatCompletionRequest
(
**
data
)
assert
restored_request
.
model
==
original_request
.
model
assert
restored_request
.
temperature
==
original_request
.
temperature
assert
restored_request
.
max_tokens
==
original_request
.
max_tokens
assert
len
(
restored_request
.
messages
)
==
len
(
original_request
.
messages
)
self
.
assert
Equal
(
restored_request
.
model
,
original_request
.
model
)
self
.
assert
Equal
(
restored_request
.
temperature
,
original_request
.
temperature
)
self
.
assert
Equal
(
restored_request
.
max_tokens
,
original_request
.
max_tokens
)
self
.
assert
Equal
(
len
(
restored_request
.
messages
)
,
len
(
original_request
.
messages
)
)
if
__name__
==
"__main__"
:
py
test
.
main
(
[
__file__
]
)
unit
test
.
main
(
verbosity
=
2
)
test/srt/openai/test_server.py
View file @
ffd1a26e
# sglang/test/srt/openai/test_server.py
import
pytest
import
requests
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
as
MODEL_ID
def
test_health
(
openai_server
:
str
):
r
=
requests
.
get
(
f
"
{
openai_server
}
/health"
)
assert
r
.
status_code
==
200
,
r
.
text
assert
r
.
status_code
==
200
# FastAPI returns an empty body → r.text == ""
assert
r
.
text
==
""
@
pytest
.
mark
.
xfail
(
reason
=
"Endpoint skeleton not implemented yet"
)
def
test_models_endpoint
(
openai_server
:
str
):
r
=
requests
.
get
(
f
"
{
openai_server
}
/v1/models"
)
# once implemented this should be 200
assert
r
.
status_code
==
200
assert
r
.
status_code
==
200
,
r
.
text
payload
=
r
.
json
()
# Basic contract
assert
"data"
in
payload
and
isinstance
(
payload
[
"data"
],
list
)
and
payload
[
"data"
]
# Validate fields of the first model card
first
=
payload
[
"data"
][
0
]
for
key
in
(
"id"
,
"root"
,
"max_model_len"
):
assert
key
in
first
,
f
"missing
{
key
}
in
{
first
}
"
# max_model_len must be positive
assert
isinstance
(
first
[
"max_model_len"
],
int
)
and
first
[
"max_model_len"
]
>
0
# The server should report the same model id we launched it with
ids
=
{
m
[
"id"
]
for
m
in
payload
[
"data"
]}
assert
MODEL_ID
in
ids
def
test_get_model_info
(
openai_server
:
str
):
r
=
requests
.
get
(
f
"
{
openai_server
}
/get_model_info"
)
assert
r
.
status_code
==
200
,
r
.
text
info
=
r
.
json
()
expected_keys
=
{
"model_path"
,
"tokenizer_path"
,
"is_generation"
}
assert
expected_keys
.
issubset
(
info
.
keys
())
# model_path must end with the one we passed on the CLI
assert
info
[
"model_path"
].
endswith
(
MODEL_ID
)
# is_generation is documented as a boolean
assert
isinstance
(
info
[
"is_generation"
],
bool
)
def
test_unknown_route_returns_404
(
openai_server
:
str
):
r
=
requests
.
get
(
f
"
{
openai_server
}
/definitely-not-a-real-route"
)
assert
r
.
status_code
==
404
test/srt/openai/test_serving_chat.py
View file @
ffd1a26e
"""
Unit tests for the OpenAIServingChat class from serving_chat.py.
These tests ensure that the refactored implementation maintains compatibility
with the original adapter.py functionality.
Unit-tests for OpenAIServingChat — rewritten to use only the std-lib 'unittest'.
Run with either:
python tests/test_serving_chat_unit.py -v
or
python -m unittest discover -s tests -p "test_*unit.py" -v
"""
import
unittest
import
uuid
from
typing
import
Optional
from
unittest.mock
import
Mock
,
patch
import
pytest
from
fastapi
import
Request
from
sglang.srt.entrypoints.openai.protocol
import
ChatCompletionRequest
,
ErrorResponse
from
sglang.srt.entrypoints.openai.protocol
import
ChatCompletionRequest
from
sglang.srt.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
sglang.srt.managers.io_struct
import
GenerateReqInput
# Mock TokenizerManager since it may not be directly importable in tests
class
MockTokenizerManager
:
class
_MockTokenizerManager
:
"""Minimal mock that satisfies OpenAIServingChat."""
def
__init__
(
self
):
self
.
model_config
=
Mock
()
self
.
model_config
.
is_multimodal
=
False
self
.
server_args
=
Mock
()
self
.
server_args
.
enable_cache_report
=
False
self
.
server_args
.
tool_call_parser
=
"hermes"
self
.
server_args
.
reasoning_parser
=
None
self
.
chat_template_name
=
"llama-3"
self
.
model_config
=
Mock
(
is_multimodal
=
False
)
self
.
server_args
=
Mock
(
enable_cache_report
=
False
,
tool_call_parser
=
"hermes"
,
reasoning_parser
=
None
,
)
self
.
chat_template_name
:
Optional
[
str
]
=
"llama-3"
#
Mock
tokenizer
# tokenizer
stub
self
.
tokenizer
=
Mock
()
self
.
tokenizer
.
encode
=
Mock
(
return_value
=
[
1
,
2
,
3
,
4
,
5
]
)
self
.
tokenizer
.
decode
=
Mock
(
return_value
=
"Test response"
)
self
.
tokenizer
.
encode
.
return_value
=
[
1
,
2
,
3
,
4
,
5
]
self
.
tokenizer
.
decode
.
return_value
=
"Test response"
self
.
tokenizer
.
chat_template
=
None
self
.
tokenizer
.
bos_token_id
=
1
#
Mock
generate_request
method
async
def
mock_generate
():
#
async generator stub for
generate_request
async
def
_
mock_generate
():
yield
{
"text"
:
"Test response"
,
"meta_info"
:
{
...
...
@@ -50,585 +53,176 @@ class MockTokenizerManager:
"index"
:
0
,
}
self
.
generate_request
=
Mock
(
return_value
=
mock_generate
())
self
.
create_abort_task
=
Mock
(
return_value
=
None
)
@
pytest
.
fixture
def
mock_tokenizer_manager
():
"""Create a mock tokenizer manager for testing."""
return
MockTokenizerManager
()
self
.
generate_request
=
Mock
(
return_value
=
_mock_generate
())
self
.
create_abort_task
=
Mock
()
@
pytest
.
fixture
def
serving_chat
(
mock_tokenizer_manager
):
"""Create a OpenAIServingChat instance for testing."""
return
OpenAIServingChat
(
mock_tokenizer_manager
)
class
ServingChatTestCase
(
unittest
.
TestCase
):
# ------------- common fixtures -------------
def
setUp
(
self
):
self
.
tm
=
_MockTokenizerManager
()
self
.
chat
=
OpenAIServingChat
(
self
.
tm
)
@
pytest
.
fixture
def
mock_request
():
"""Create a mock FastAPI request."""
request
=
Mock
(
spec
=
Request
)
request
.
headers
=
{}
return
request
@
pytest
.
fixture
def
basic_chat_request
():
"""Create a basic chat completion request."""
return
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello, how are you?"
}],
temperature
=
0.7
,
max_tokens
=
100
,
stream
=
False
,
)
@
pytest
.
fixture
def
streaming_chat_request
():
"""Create a streaming chat completion request."""
return
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello, how are you?"
}],
temperature
=
0.7
,
max_tokens
=
100
,
stream
=
True
,
)
# frequently reused requests
self
.
basic_req
=
ChatCompletionRequest
(
model
=
"x"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hi?"
}],
temperature
=
0.7
,
max_tokens
=
100
,
stream
=
False
,
)
self
.
stream_req
=
ChatCompletionRequest
(
model
=
"x"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hi?"
}],
temperature
=
0.7
,
max_tokens
=
100
,
stream
=
True
,
)
class
TestOpenAIServingChatConversion
:
"""Test request conversion methods."""
self
.
fastapi_request
=
Mock
(
spec
=
Request
)
self
.
fastapi_request
.
headers
=
{}
def
test_convert_to_internal_request_single
(
self
,
serving_chat
,
basic_chat_request
,
mock_tokenizer_manager
):
"""Test converting single request to internal format."""
# ------------- conversion tests -------------
def
test_convert_to_internal_request_single
(
self
):
with
patch
(
"sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
)
as
mock_conv
:
mock_conv_instance
=
Mock
()
mock_conv_instance
.
get_prompt
.
return_value
=
"Test prompt"
mock_conv_instance
.
image_data
=
None
mock_conv_instance
.
audio_data
=
None
mock_conv_instance
.
modalities
=
[]
mock_conv_instance
.
stop_str
=
[
"</s>"
]
mock_conv
.
return_value
=
mock_conv_instance
# Mock the _process_messages method to return expected values
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
mock_process
.
return_value
=
(
"Test prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
,
# tool_call_constraint
)
adapted_request
,
processed_request
=
(
serving_chat
.
_convert_to_internal_request
(
[
basic_chat_request
],
[
"test-id"
]
)
)
assert
isinstance
(
adapted_request
,
GenerateReqInput
)
assert
adapted_request
.
stream
==
basic_chat_request
.
stream
assert
processed_request
==
basic_chat_request
)
as
conv_mock
,
patch
.
object
(
self
.
chat
,
"_process_messages"
)
as
proc_mock
:
conv_ins
=
Mock
()
conv_ins
.
get_prompt
.
return_value
=
"Test prompt"
conv_ins
.
image_data
=
conv_ins
.
audio_data
=
None
conv_ins
.
modalities
=
[]
conv_ins
.
stop_str
=
[
"</s>"
]
conv_mock
.
return_value
=
conv_ins
proc_mock
.
return_value
=
(
"Test prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
,
)
class
TestToolCalls
:
"""Test tool call functionality from adapter.py"""
adapted
,
processed
=
self
.
chat
.
_convert_to_internal_request
(
[
self
.
basic_req
],
[
"rid"
]
)
self
.
assertIsInstance
(
adapted
,
GenerateReqInput
)
self
.
assertFalse
(
adapted
.
stream
)
self
.
assertEqual
(
processed
,
self
.
basic_req
)
def
test_
tool
_
call
_request_conversion
(
self
,
serving_chat
):
"""Test request with
tool
call
s"""
req
uest
=
ChatCompletionRequest
(
model
=
"
test-model
"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"W
hat's the w
eather?"
}],
# -------------
tool
-
call
branch -------------
def
test_
tool
_
call
_request_conversion
(
self
):
req
=
ChatCompletionRequest
(
model
=
"
x
"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Weather?"
}],
tools
=
[
{
"type"
:
"function"
,
"function"
:
{
"name"
:
"get_weather"
,
"description"
:
"Get weather information"
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"location"
:
{
"type"
:
"string"
}},
},
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{}},
},
}
],
tool_choice
=
"auto"
,
)
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
mock_process
.
return_value
=
(
"Test prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
,
# tool_call_constraint
)
adapted_request
,
_
=
serving_chat
.
_convert_to_internal_request
(
[
request
],
[
"test-id"
]
)
assert
adapted_request
.
rid
==
"test-id"
# Tool call constraint should be processed
assert
request
.
tools
is
not
None
def
test_tool_choice_none
(
self
,
serving_chat
):
"""Test tool_choice=none disables tool calls"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}],
tools
=
[{
"type"
:
"function"
,
"function"
:
{
"name"
:
"test_func"
}}],
with
patch
.
object
(
self
.
chat
,
"_process_messages"
,
return_value
=
(
"Prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
),
):
adapted
,
_
=
self
.
chat
.
_convert_to_internal_request
([
req
],
[
"rid"
])
self
.
assertEqual
(
adapted
.
rid
,
"rid"
)
def
test_tool_choice_none
(
self
):
req
=
ChatCompletionRequest
(
model
=
"x"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hi"
}],
tools
=
[{
"type"
:
"function"
,
"function"
:
{
"name"
:
"noop"
}}],
tool_choice
=
"none"
,
)
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
mock_process
.
return_value
=
(
"Test prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
,
# tool_call_constraint
)
adapted_request
,
_
=
serving_chat
.
_convert_to_internal_request
(
[
request
],
[
"test-id"
]
)
# Tools should not be processed when tool_choice is "none"
assert
adapted_request
.
rid
==
"test-id"
def
test_tool_call_response_processing
(
self
,
serving_chat
):
"""Test processing tool calls in response"""
mock_ret_item
=
{
"text"
:
'{"name": "get_weather", "parameters": {"location": "Paris"}}'
,
"meta_info"
:
{
"output_token_logprobs"
:
[],
"output_top_logprobs"
:
None
,
},
}
tools
=
[
{
"type"
:
"function"
,
"function"
:
{
"name"
:
"get_weather"
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"location"
:
{
"type"
:
"string"
}},
},
},
}
]
finish_reason
=
{
"type"
:
"stop"
,
"matched"
:
None
}
# Mock FunctionCallParser
with
patch
(
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
)
as
mock_parser_class
:
mock_parser
=
Mock
()
mock_parser
.
has_tool_call
.
return_value
=
True
# Create proper mock tool call object
mock_tool_call
=
Mock
()
mock_tool_call
.
name
=
"get_weather"
mock_tool_call
.
parameters
=
'{"location": "Paris"}'
mock_parser
.
parse_non_stream
.
return_value
=
(
""
,
[
mock_tool_call
])
mock_parser_class
.
return_value
=
mock_parser
tool_calls
,
text
,
updated_finish_reason
=
serving_chat
.
_process_tool_calls
(
mock_ret_item
[
"text"
],
tools
,
"hermes"
,
finish_reason
)
assert
tool_calls
is
not
None
assert
len
(
tool_calls
)
==
1
assert
updated_finish_reason
[
"type"
]
==
"tool_calls"
class
TestMultimodalContent
:
"""Test multimodal content handling from adapter.py"""
def
test_multimodal_request_with_images
(
self
,
serving_chat
):
"""Test request with image content"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
with
patch
.
object
(
self
.
chat
,
"_process_messages"
,
return_value
=
(
"Prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
),
):
adapted
,
_
=
self
.
chat
.
_convert_to_internal_request
([
req
],
[
"rid"
])
self
.
assertEqual
(
adapted
.
rid
,
"rid"
)
# ------------- multimodal branch -------------
def
test_multimodal_request_with_images
(
self
):
self
.
tm
.
model_config
.
is_multimodal
=
True
req
=
ChatCompletionRequest
(
model
=
"x"
,
messages
=
[
{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"What's in th
is
image?"
},
{
"type"
:
"text"
,
"text"
:
"What's in th
e
image?"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"data:image/jpeg;base64,
...
"
},
"image_url"
:
{
"url"
:
"data:image/jpeg;base64,"
},
},
],
}
],
)
# Set multimodal mode
serving_chat
.
tokenizer_manager
.
model_config
.
is_multimodal
=
True
with
patch
.
object
(
serving_chat
,
"_apply_jinja_template"
)
as
mock_apply
:
mock_apply
.
return_value
=
(
"prompt"
,
[
1
,
2
,
3
],
[
"image_data"
],
None
,
[],
[],
)
with
patch
.
object
(
serving_chat
,
"_apply_conversation_template"
)
as
mock_conv
:
mock_conv
.
return_value
=
(
"prompt"
,
[
"image_data"
],
None
,
[],
[])
(
prompt
,
prompt_ids
,
image_data
,
audio_data
,
modalities
,
stop
,
tool_call_constraint
,
)
=
serving_chat
.
_process_messages
(
request
,
True
)
assert
image_data
==
[
"image_data"
]
assert
prompt
==
"prompt"
def
test_multimodal_request_with_audio
(
self
,
serving_chat
):
"""Test request with audio content"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[
{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"Transcribe this audio"
},
{
"type"
:
"audio_url"
,
"audio_url"
:
{
"url"
:
"data:audio/wav;base64,UklGR..."
},
},
],
}
],
with
patch
.
object
(
self
.
chat
,
"_apply_jinja_template"
,
return_value
=
(
"prompt"
,
[
1
,
2
],
[
"img"
],
None
,
[],
[]),
),
patch
.
object
(
self
.
chat
,
"_apply_conversation_template"
,
return_value
=
(
"prompt"
,
[
"img"
],
None
,
[],
[]),
):
out
=
self
.
chat
.
_process_messages
(
req
,
True
)
_
,
_
,
image_data
,
*
_
=
out
self
.
assertEqual
(
image_data
,
[
"img"
])
# ------------- template handling -------------
def
test_jinja_template_processing
(
self
):
req
=
ChatCompletionRequest
(
model
=
"x"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
)
serving_chat
.
tokenizer_manager
.
model_config
.
is_multimodal
=
True
with
patch
.
object
(
serving_chat
,
"_apply_jinja_template"
)
as
mock_apply
:
mock_apply
.
return_value
=
(
"prompt"
,
[
1
,
2
,
3
],
None
,
[
"audio_data"
],
[
"audio"
],
[],
)
with
patch
.
object
(
serving_chat
,
"_apply_conversation_template"
)
as
mock_conv
:
mock_conv
.
return_value
=
(
"prompt"
,
None
,
[
"audio_data"
],
[
"audio"
],
[])
(
prompt
,
prompt_ids
,
image_data
,
audio_data
,
modalities
,
stop
,
tool_call_constraint
,
)
=
serving_chat
.
_process_messages
(
request
,
True
)
assert
audio_data
==
[
"audio_data"
]
assert
modalities
==
[
"audio"
]
class
TestTemplateHandling
:
"""Test chat template handling from adapter.py"""
def
test_jinja_template_processing
(
self
,
serving_chat
):
"""Test Jinja template processing"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
)
# Mock the template attribute directly
serving_chat
.
tokenizer_manager
.
chat_template_name
=
None
serving_chat
.
tokenizer_manager
.
tokenizer
.
chat_template
=
"<jinja_template>"
with
patch
.
object
(
serving_chat
,
"_apply_jinja_template"
)
as
mock_apply
:
mock_apply
.
return_value
=
(
"processed_prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
)
# Mock hasattr to simulate the None check
with
patch
(
"builtins.hasattr"
)
as
mock_hasattr
:
mock_hasattr
.
return_value
=
True
(
prompt
,
prompt_ids
,
image_data
,
audio_data
,
modalities
,
stop
,
tool_call_constraint
,
)
=
serving_chat
.
_process_messages
(
request
,
False
)
assert
prompt
==
"processed_prompt"
assert
prompt_ids
==
[
1
,
2
,
3
]
def
test_conversation_template_processing
(
self
,
serving_chat
):
"""Test conversation template processing"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
)
serving_chat
.
tokenizer_manager
.
chat_template_name
=
"llama-3"
with
patch
.
object
(
serving_chat
,
"_apply_conversation_template"
)
as
mock_apply
:
mock_apply
.
return_value
=
(
"conv_prompt"
,
None
,
None
,
[],
[
"</s>"
])
(
prompt
,
prompt_ids
,
image_data
,
audio_data
,
modalities
,
stop
,
tool_call_constraint
,
)
=
serving_chat
.
_process_messages
(
request
,
False
)
assert
prompt
==
"conv_prompt"
assert
stop
==
[
"</s>"
]
def
test_continue_final_message
(
self
,
serving_chat
):
"""Test continue_final_message functionality"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[
{
"role"
:
"user"
,
"content"
:
"Hello"
},
{
"role"
:
"assistant"
,
"content"
:
"Hi there"
},
],
continue_final_message
=
True
,
)
with
patch
.
object
(
serving_chat
,
"_apply_conversation_template"
)
as
mock_apply
:
mock_apply
.
return_value
=
(
"Hi there"
,
None
,
None
,
[],
[
"</s>"
])
(
prompt
,
prompt_ids
,
image_data
,
audio_data
,
modalities
,
stop
,
tool_call_constraint
,
)
=
serving_chat
.
_process_messages
(
request
,
False
)
# Should handle continue_final_message properly
assert
prompt
==
"Hi there"
class
TestReasoningContent
:
"""Test reasoning content separation from adapter.py"""
def
test_reasoning_content_request
(
self
,
serving_chat
):
"""Test request with reasoning content separation"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Solve this math problem"
}],
separate_reasoning
=
True
,
stream_reasoning
=
False
,
)
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
mock_process
.
return_value
=
(
"Test prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
,
# tool_call_constraint
)
adapted_request
,
_
=
serving_chat
.
_convert_to_internal_request
(
[
request
],
[
"test-id"
]
)
assert
adapted_request
.
rid
==
"test-id"
assert
request
.
separate_reasoning
==
True
def
test_reasoning_content_response
(
self
,
serving_chat
):
"""Test reasoning content in response"""
mock_ret_item
=
{
"text"
:
"<thinking>This is reasoning</thinking>Answer: 42"
,
"meta_info"
:
{
"output_token_logprobs"
:
[],
"output_top_logprobs"
:
None
,
},
}
# Mock ReasoningParser
with
patch
(
"sglang.srt.entrypoints.openai.serving_chat.ReasoningParser"
)
as
mock_parser_class
:
mock_parser
=
Mock
()
mock_parser
.
parse_non_stream
.
return_value
=
(
"This is reasoning"
,
"Answer: 42"
,
)
mock_parser_class
.
return_value
=
mock_parser
choice_logprobs
=
None
reasoning_text
=
None
text
=
mock_ret_item
[
"text"
]
# Simulate reasoning processing
enable_thinking
=
True
if
enable_thinking
:
parser
=
mock_parser_class
(
model_type
=
"test"
,
stream_reasoning
=
False
)
reasoning_text
,
text
=
parser
.
parse_non_stream
(
text
)
assert
reasoning_text
==
"This is reasoning"
assert
text
==
"Answer: 42"
class
TestSamplingParams
:
"""Test sampling parameter handling from adapter.py"""
def
test_all_sampling_parameters
(
self
,
serving_chat
):
"""Test all sampling parameters are properly handled"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}],
self
.
tm
.
chat_template_name
=
None
self
.
tm
.
tokenizer
.
chat_template
=
"<jinja>"
with
patch
.
object
(
self
.
chat
,
"_apply_jinja_template"
,
return_value
=
(
"processed"
,
[
1
],
None
,
None
,
[],
[
"</s>"
]),
),
patch
(
"builtins.hasattr"
,
return_value
=
True
):
prompt
,
prompt_ids
,
*
_
=
self
.
chat
.
_process_messages
(
req
,
False
)
self
.
assertEqual
(
prompt
,
"processed"
)
self
.
assertEqual
(
prompt_ids
,
[
1
])
# ------------- sampling-params -------------
def
test_sampling_param_build
(
self
):
req
=
ChatCompletionRequest
(
model
=
"x"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hi"
}],
temperature
=
0.8
,
max_tokens
=
150
,
max_completion_tokens
=
200
,
min_tokens
=
5
,
top_p
=
0.9
,
top_k
=
50
,
min_p
=
0.1
,
presence_penalty
=
0.1
,
frequency_penalty
=
0.2
,
repetition_penalty
=
1.1
,
stop
=
[
"<|endoftext|>"
],
stop_token_ids
=
[
13
,
14
],
regex
=
r
"\d+"
,
ebnf
=
"<expr> ::= <number>"
,
n
=
2
,
no_stop_trim
=
True
,
ignore_eos
=
True
,
skip_special_tokens
=
False
,
logit_bias
=
{
"1"
:
0.5
,
"2"
:
-
0.3
},
)
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
mock_process
.
return_value
=
(
"Test prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
,
# tool_call_constraint
)
sampling_params
=
serving_chat
.
_build_sampling_params
(
request
,
[
"</s>"
],
None
)
# Verify all parameters
assert
sampling_params
[
"temperature"
]
==
0.8
assert
sampling_params
[
"max_new_tokens"
]
==
150
assert
sampling_params
[
"min_new_tokens"
]
==
5
assert
sampling_params
[
"top_p"
]
==
0.9
assert
sampling_params
[
"top_k"
]
==
50
assert
sampling_params
[
"min_p"
]
==
0.1
assert
sampling_params
[
"presence_penalty"
]
==
0.1
assert
sampling_params
[
"frequency_penalty"
]
==
0.2
assert
sampling_params
[
"repetition_penalty"
]
==
1.1
assert
sampling_params
[
"stop"
]
==
[
"</s>"
]
assert
sampling_params
[
"logit_bias"
]
==
{
"1"
:
0.5
,
"2"
:
-
0.3
}
def
test_response_format_json_schema
(
self
,
serving_chat
):
"""Test response format with JSON schema"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Generate JSON"
}],
response_format
=
{
"type"
:
"json_schema"
,
"json_schema"
:
{
"name"
:
"response"
,
"schema"
:
{
"type"
:
"object"
,
"properties"
:
{
"answer"
:
{
"type"
:
"string"
}},
},
},
},
)
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
mock_process
.
return_value
=
(
"Test prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
,
# tool_call_constraint
)
sampling_params
=
serving_chat
.
_build_sampling_params
(
request
,
[
"</s>"
],
None
)
assert
"json_schema"
in
sampling_params
assert
'"type": "object"'
in
sampling_params
[
"json_schema"
]
def
test_response_format_json_object
(
self
,
serving_chat
):
"""Test response format with JSON object"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Generate JSON"
}],
response_format
=
{
"type"
:
"json_object"
},
stop
=
[
"</s>"
],
)
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
mock_process
.
return_value
=
(
"Test prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
,
# tool_call_constraint
)
sampling_params
=
serving_chat
.
_build_sampling_params
(
request
,
[
"</s>"
],
None
)
assert
sampling_params
[
"json_schema"
]
==
'{"type": "object"}'
with
patch
.
object
(
self
.
chat
,
"_process_messages"
,
return_value
=
(
"Prompt"
,
[
1
],
None
,
None
,
[],
[
"</s>"
],
None
),
):
params
=
self
.
chat
.
_build_sampling_params
(
req
,
[
"</s>"
],
None
)
self
.
assertEqual
(
params
[
"temperature"
],
0.8
)
self
.
assertEqual
(
params
[
"max_new_tokens"
],
150
)
self
.
assertEqual
(
params
[
"min_new_tokens"
],
5
)
self
.
assertEqual
(
params
[
"stop"
],
[
"</s>"
])
if
__name__
==
"__main__"
:
unittest
.
main
(
verbosity
=
2
)
test/srt/openai/test_serving_completions.py
View file @
ffd1a26e
"""
Tests for the refactored completions serving handler
Unit-tests for the refactored completions-serving handler (no pytest).
Run with:
python -m unittest tests.test_serving_completions_unit -v
"""
import
unittest
from
unittest.mock
import
AsyncMock
,
Mock
,
patch
import
pytest
from
sglang.srt.entrypoints.openai.protocol
import
(
CompletionRequest
,
CompletionResponse
,
CompletionResponseChoice
,
CompletionStreamResponse
,
ErrorResponse
,
)
from
sglang.srt.entrypoints.openai.protocol
import
CompletionRequest
from
sglang.srt.entrypoints.openai.serving_completions
import
OpenAIServingCompletion
from
sglang.srt.managers.io_struct
import
GenerateReqInput
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
@
pytest
.
fixture
def
mock_tokenizer_manager
():
"""Create a mock tokenizer manager"""
manager
=
Mock
(
spec
=
TokenizerManager
)
# Mock tokenizer
manager
.
tokenizer
=
Mock
()
manager
.
tokenizer
.
encode
=
Mock
(
return_value
=
[
1
,
2
,
3
,
4
])
manager
.
tokenizer
.
decode
=
Mock
(
return_value
=
"decoded text"
)
manager
.
tokenizer
.
bos_token_id
=
1
# Mock model config
manager
.
model_config
=
Mock
()
manager
.
model_config
.
is_multimodal
=
False
# Mock server args
manager
.
server_args
=
Mock
()
manager
.
server_args
.
enable_cache_report
=
False
class
ServingCompletionTestCase
(
unittest
.
TestCase
):
"""Bundle all prompt/echo tests in one TestCase."""
# Mock generation
manager
.
generate_request
=
AsyncMock
()
manager
.
create_abort_task
=
Mock
(
return_value
=
None
)
# ---------- shared test fixtures ----------
def
setUp
(
self
):
# build the mock TokenizerManager once for every test
tm
=
Mock
(
spec
=
TokenizerManager
)
return
manager
tm
.
tokenizer
=
Mock
()
tm
.
tokenizer
.
encode
.
return_value
=
[
1
,
2
,
3
,
4
]
tm
.
tokenizer
.
decode
.
return_value
=
"decoded text"
tm
.
tokenizer
.
bos_token_id
=
1
tm
.
model_config
=
Mock
(
is_multimodal
=
False
)
tm
.
server_args
=
Mock
(
enable_cache_report
=
False
)
@
pytest
.
fixture
def
serving_completion
(
mock_tokenizer_manager
):
"""Create a OpenAIServingCompletion instance"""
return
OpenAIServingCompletion
(
mock_tokenizer_manager
)
tm
.
generate_request
=
AsyncMock
()
tm
.
create_abort_task
=
Mock
()
self
.
sc
=
OpenAIServingCompletion
(
tm
)
class TestPromptHandling:
    """Test different prompt types and formats from adapter.py"""

    def test_single_string_prompt(self, serving_completion):
        """Test handling single string prompt"""
        request = CompletionRequest(
            model="test-model", prompt="Hello world", max_tokens=100
        )

        adapted_request, _ = serving_completion._convert_to_internal_request(
            [request], ["test-id"]
        )

        assert adapted_request.text == "Hello world"

    def test_single_token_ids_prompt(self, serving_completion):
        """Test handling single token IDs prompt"""
        request = CompletionRequest(
            model="test-model", prompt=[1, 2, 3, 4], max_tokens=100
        )

        adapted_request, _ = serving_completion._convert_to_internal_request(
            [request], ["test-id"]
        )

        assert adapted_request.input_ids == [1, 2, 3, 4]

    def test_completion_template_handling(self, serving_completion):
        """Test completion template processing"""
        request = CompletionRequest(
            model="test-model",
            prompt="def hello():",
            suffix="return 'world'",
            max_tokens=100,
        )

        with patch(
            "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
            return_value=True,
        ), patch(
            "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
            return_value="processed_prompt",
        ):
            adapted_request, _ = serving_completion._convert_to_internal_request(
                [request], ["test-id"]
            )

            assert adapted_request.text == "processed_prompt"

    # ---------- prompt-handling ----------
    def test_single_string_prompt(self):
        req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
        internal, _ = self.sc._convert_to_internal_request([req], ["id"])
        self.assertEqual(internal.text, "Hello world")

    def test_single_token_ids_prompt(self):
        req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
        internal, _ = self.sc._convert_to_internal_request([req], ["id"])
        self.assertEqual(internal.input_ids, [1, 2, 3, 4])

    def test_completion_template_handling(self):
        req = CompletionRequest(
            model="x", prompt="def f():", suffix="return 1", max_tokens=100
        )
        with patch(
            "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
            return_value=True,
        ), patch(
            "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
            return_value="processed_prompt",
        ):
            internal, _ = self.sc._convert_to_internal_request([req], ["id"])
        self.assertEqual(internal.text, "processed_prompt")
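The prompt-handling tests above pivot on a single dispatch rule: a string prompt travels through the internal request as text, while a pre-tokenized prompt travels as input_ids. A rough sketch of that rule, using a simplified stand-in for GenerateReqInput rather than the real class:

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class SimpleGenerateReqInput:
    # Simplified stand-in for GenerateReqInput, for illustration only.
    rid: str
    text: Optional[str] = None
    input_ids: Optional[List[int]] = None


def convert_prompt(prompt: Union[str, List[int]], rid: str) -> SimpleGenerateReqInput:
    """Dispatch on prompt type: raw text vs. pre-tokenized IDs."""
    if isinstance(prompt, str):
        return SimpleGenerateReqInput(rid=rid, text=prompt)
    return SimpleGenerateReqInput(rid=rid, input_ids=list(prompt))


assert convert_prompt("Hello world", "id").text == "Hello world"
assert convert_prompt([1, 2, 3, 4], "id").input_ids == [1, 2, 3, 4]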
class TestEchoHandling:
    """Test echo functionality from adapter.py"""

    def test_echo_with_string_prompt_streaming(self, serving_completion):
        """Test echo handling with string prompt in streaming"""
        request = CompletionRequest(
            model="test-model", prompt="Hello", max_tokens=100, echo=True
        )

        # Test _get_echo_text method
        echo_text = serving_completion._get_echo_text(request, 0)
        assert echo_text == "Hello"

    def test_echo_with_list_of_strings_streaming(self, serving_completion):
        """Test echo handling with list of strings in streaming"""
        request = CompletionRequest(
            model="test-model",
            prompt=["Hello", "World"],
            max_tokens=100,
            echo=True,
            n=1,
        )

        echo_text = serving_completion._get_echo_text(request, 0)
        assert echo_text == "Hello"

        echo_text = serving_completion._get_echo_text(request, 1)
        assert echo_text == "World"

    def test_echo_with_token_ids_streaming(self, serving_completion):
        """Test echo handling with token IDs in streaming"""
        request = CompletionRequest(
            model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True
        )

        serving_completion.tokenizer_manager.tokenizer.decode.return_value = (
            "decoded_prompt"
        )

        echo_text = serving_completion._get_echo_text(request, 0)
        assert echo_text == "decoded_prompt"

    def test_echo_with_multiple_token_ids_streaming(self, serving_completion):
        """Test echo handling with multiple token ID prompts in streaming"""
        request = CompletionRequest(
            model="test-model", prompt=[[1, 2], [3, 4]], max_tokens=100, echo=True, n=1
        )

        serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"

        echo_text = serving_completion._get_echo_text(request, 0)
        assert echo_text == "decoded"

    def test_prepare_echo_prompts_non_streaming(self, serving_completion):
        """Test prepare echo prompts for non-streaming response"""
        # Test with single string
        request = CompletionRequest(model="test-model", prompt="Hello", echo=True)
        echo_prompts = serving_completion._prepare_echo_prompts(request)
        assert echo_prompts == ["Hello"]

        # Test with list of strings
        request = CompletionRequest(
            model="test-model", prompt=["Hello", "World"], echo=True
        )
        echo_prompts = serving_completion._prepare_echo_prompts(request)
        assert echo_prompts == ["Hello", "World"]

        # Test with token IDs
        request = CompletionRequest(model="test-model", prompt=[1, 2, 3], echo=True)
        serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
        echo_prompts = serving_completion._prepare_echo_prompts(request)
        assert echo_prompts == ["decoded"]

    # ---------- echo-handling ----------
    def test_echo_with_string_prompt_streaming(self):
        req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
        self.assertEqual(self.sc._get_echo_text(req, 0), "Hello")

    def test_echo_with_list_of_strings_streaming(self):
        req = CompletionRequest(
            model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1
        )
        self.assertEqual(self.sc._get_echo_text(req, 0), "A")
        self.assertEqual(self.sc._get_echo_text(req, 1), "B")

    def test_echo_with_token_ids_streaming(self):
        req = CompletionRequest(model="x", prompt=[1, 2, 3], max_tokens=1, echo=True)
        self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded_prompt"
        self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt")

    def test_echo_with_multiple_token_ids_streaming(self):
        req = CompletionRequest(
            model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1
        )
        self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
        self.assertEqual(self.sc._get_echo_text(req, 0), "decoded")

    def test_prepare_echo_prompts_non_streaming(self):
        # single string
        req = CompletionRequest(model="x", prompt="Hi", echo=True)
        self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"])

        # list of strings
        req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True)
        self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"])

        # token IDs
        req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True)
        self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
        self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])


if __name__ == "__main__":
    unittest.main(verbosity=2)
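The echo tests encode a small contract for echo=True: sample i is prefixed with the i-th string prompt, or with the decoded form of the i-th token-ID prompt. A compact sketch of that contract under simplified signatures, with decode standing in for the tokenizer:

from typing import Callable, List, Union

Prompt = Union[str, List[int], List[str], List[List[int]]]


def get_echo_text(prompt: Prompt, index: int, decode: Callable[..., str]) -> str:
    """Return the text echo=True should prepend for sample `index` (sketch)."""
    if isinstance(prompt, str):
        return prompt  # a single string prompt echoes as-is
    if isinstance(prompt[0], str):
        return prompt[index]  # list of string prompts: pick the index-th
    if isinstance(prompt[0], int):
        return decode(prompt)  # a single token-ID prompt: decode it
    return decode(prompt[index])  # list of token-ID prompts: decode the index-th


fake_decode = lambda ids: "decoded"
assert get_echo_text("Hello", 0, fake_decode) == "Hello"
assert get_echo_text(["A", "B"], 1, fake_decode) == "B"
assert get_echo_text([1, 2, 3], 0, fake_decode) == "decoded"
assert get_echo_text([[1, 2], [3, 4]], 0, fake_decode) == "decoded"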
test/srt/openai/test_serving_embedding.py  View file @ ffd1a26e
...

@@ -8,11 +8,11 @@ with the original adapter.py functionality and follows OpenAI API specifications

import asyncio
import json
import time
import unittest
import uuid
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch

import pytest
from fastapi import Request
from fastapi.responses import ORJSONResponse
from pydantic_core import ValidationError

...

@@ -30,7 +30,7 @@ from sglang.srt.managers.io_struct import EmbeddingReqInput

# Mock TokenizerManager for embedding tests
class MockTokenizerManager:
class _MockTokenizerManager:
    def __init__(self):
        self.model_config = Mock()
        self.model_config.is_multimodal = False
...

@@ -58,141 +58,98 @@ class MockTokenizerManager:
        self.generate_request = Mock(return_value=mock_generate_embedding())


@pytest.fixture
def mock_tokenizer_manager():
    """Create a mock tokenizer manager for testing."""
    return MockTokenizerManager()


@pytest.fixture
def serving_embedding(mock_tokenizer_manager):
    """Create an OpenAIServingEmbedding instance for testing."""
    return OpenAIServingEmbedding(mock_tokenizer_manager)


@pytest.fixture
def mock_request():
    """Create a mock FastAPI request."""
    request = Mock(spec=Request)
    request.headers = {}
    return request


@pytest.fixture
def basic_embedding_request():
    """Create a basic embedding request."""
    return EmbeddingRequest(
        model="test-model",
        input="Hello, how are you?",
        encoding_format="float",
    )


@pytest.fixture
def list_embedding_request():
    """Create an embedding request with list input."""
    return EmbeddingRequest(
        model="test-model",
        input=["Hello, how are you?", "I am fine, thank you!"],
        encoding_format="float",
    )


@pytest.fixture
def multimodal_embedding_request():
    """Create a multimodal embedding request."""
    return EmbeddingRequest(
        model="test-model",
        input=[
            MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
            MultimodalEmbeddingInput(text="World", image=None),
        ],
        encoding_format="float",
    )


@pytest.fixture
def token_ids_embedding_request():
    """Create an embedding request with token IDs."""
    return EmbeddingRequest(
        model="test-model",
        input=[1, 2, 3, 4, 5],
        encoding_format="float",
    )


class TestOpenAIServingEmbeddingConversion:
    """Test request conversion methods."""


class ServingEmbeddingTestCase(unittest.TestCase):
    def setUp(self):
        """Set up test fixtures."""
        self.tokenizer_manager = _MockTokenizerManager()
        self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)
        self.request = Mock(spec=Request)
        self.request.headers = {}

        self.basic_req = EmbeddingRequest(
            model="test-model",
            input="Hello, how are you?",
            encoding_format="float",
        )
        self.list_req = EmbeddingRequest(
            model="test-model",
            input=["Hello, how are you?", "I am fine, thank you!"],
            encoding_format="float",
        )
        self.multimodal_req = EmbeddingRequest(
            model="test-model",
            input=[
                MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
                MultimodalEmbeddingInput(text="World", image=None),
            ],
            encoding_format="float",
        )
        self.token_ids_req = EmbeddingRequest(
            model="test-model",
            input=[1, 2, 3, 4, 5],
            encoding_format="float",
        )
    def test_convert_single_string_request(
        self, serving_embedding, basic_embedding_request
    ):
        """Test converting single string request to internal format."""
        adapted_request, processed_request = (
            serving_embedding._convert_to_internal_request(
                [basic_embedding_request], ["test-id"]
            )
        )

        assert isinstance(adapted_request, EmbeddingReqInput)
        assert adapted_request.text == "Hello, how are you?"
        assert adapted_request.rid == "test-id"
        assert processed_request == basic_embedding_request

    def test_convert_single_string_request(self):
        """Test converting single string request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(
                [self.basic_req], ["test-id"]
            )
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.text, "Hello, how are you?")
        self.assertEqual(adapted_request.rid, "test-id")
        self.assertEqual(processed_request, self.basic_req)

    def test_convert_list_string_request(
        self, serving_embedding, list_embedding_request
    ):
        """Test converting list of strings request to internal format."""
        adapted_request, processed_request = (
            serving_embedding._convert_to_internal_request(
                [list_embedding_request], ["test-id"]
            )
        )

        assert isinstance(adapted_request, EmbeddingReqInput)
        assert adapted_request.text == ["Hello, how are you?", "I am fine, thank you!"]
        assert adapted_request.rid == "test-id"
        assert processed_request == list_embedding_request

    def test_convert_list_string_request(self):
        """Test converting list of strings request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(
                [self.list_req], ["test-id"]
            )
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(
            adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
        )
        self.assertEqual(adapted_request.rid, "test-id")
        self.assertEqual(processed_request, self.list_req)

    def test_convert_token_ids_request(
        self, serving_embedding, token_ids_embedding_request
    ):
        """Test converting token IDs request to internal format."""
        adapted_request, processed_request = (
            serving_embedding._convert_to_internal_request(
                [token_ids_embedding_request], ["test-id"]
            )
        )

        assert isinstance(adapted_request, EmbeddingReqInput)
        assert adapted_request.input_ids == [1, 2, 3, 4, 5]
        assert adapted_request.rid == "test-id"
        assert processed_request == token_ids_embedding_request

    def test_convert_token_ids_request(self):
        """Test converting token IDs request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(
                [self.token_ids_req], ["test-id"]
            )
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
        self.assertEqual(adapted_request.rid, "test-id")
        self.assertEqual(processed_request, self.token_ids_req)

    def test_convert_multimodal_request(
        self, serving_embedding, multimodal_embedding_request
    ):
        """Test converting multimodal request to internal format."""
        adapted_request, processed_request = (
            serving_embedding._convert_to_internal_request(
                [multimodal_embedding_request], ["test-id"]
            )
        )

        assert isinstance(adapted_request, EmbeddingReqInput)

        # Should extract text and images separately
        assert len(adapted_request.text) == 2
        assert "Hello" in adapted_request.text
        assert "World" in adapted_request.text
        assert adapted_request.image_data[0] == "base64_image_data"
        assert adapted_request.image_data[1] is None
        assert adapted_request.rid == "test-id"

    def test_convert_multimodal_request(self):
        """Test converting multimodal request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(
                [self.multimodal_req], ["test-id"]
            )
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)

        # Should extract text and images separately
        self.assertEqual(len(adapted_request.text), 2)
        self.assertIn("Hello", adapted_request.text)
        self.assertIn("World", adapted_request.text)
        self.assertEqual(adapted_request.image_data[0], "base64_image_data")
        self.assertIsNone(adapted_request.image_data[1])
        self.assertEqual(adapted_request.rid, "test-id")
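The multimodal test above checks that paired text/image inputs are unzipped into parallel text and image_data lists. A small sketch of that step, with MMInput as a hypothetical stand-in for MultimodalEmbeddingInput:

from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class MMInput:
    # Hypothetical stand-in for MultimodalEmbeddingInput.
    text: Optional[str]
    image: Optional[str]


def split_multimodal(
    items: List[MMInput],
) -> Tuple[List[Optional[str]], List[Optional[str]]]:
    """Unzip paired inputs into parallel text / image_data lists."""
    texts = [item.text for item in items]
    images = [item.image for item in items]
    return texts, images


texts, images = split_multimodal(
    [MMInput("Hello", "base64_image_data"), MMInput("World", None)]
)
assert texts == ["Hello", "World"] and images == ["base64_image_data", None]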
class TestEmbeddingResponseBuilding:
    """Test response building methods."""

    def test_build_single_embedding_response(self, serving_embedding):
        """Test building response for single embedding."""
        ret_data = [
            {
                ...
            }
        ]

        response = serving_embedding._build_embedding_response(ret_data, "test-model")

        assert isinstance(response, EmbeddingResponse)
        assert response.model == "test-model"
        assert len(response.data) == 1
        assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
        assert response.data[0].index == 0
        assert response.data[0].object == "embedding"
        assert response.usage.prompt_tokens == 5
        assert response.usage.total_tokens == 5
        assert response.usage.completion_tokens == 0

    def test_build_multiple_embedding_response(self, serving_embedding):
        """Test building response for multiple embeddings."""
        ret_data = [
            {
                ...
            },
        ]

        response = serving_embedding._build_embedding_response(ret_data, "test-model")

        assert isinstance(response, EmbeddingResponse)
        assert len(response.data) == 2
        assert response.data[0].embedding == [0.1, 0.2, 0.3]
        assert response.data[0].index == 0
        assert response.data[1].embedding == [0.4, 0.5, 0.6]
        assert response.data[1].index == 1
        assert response.usage.prompt_tokens == 7  # 3 + 4
        assert response.usage.total_tokens == 7


@pytest.mark.asyncio
class TestOpenAIServingEmbeddingAsyncMethods:
    """Test async methods of OpenAIServingEmbedding."""

    async def test_handle_request_success(
        self, serving_embedding, basic_embedding_request, mock_request
    ):
        """Test successful embedding request handling."""
        # Mock the generate_request to return expected data
        ...
            "meta_info": {"prompt_tokens": 5},
        }

        serving_embedding.tokenizer_manager.generate_request = Mock(
            return_value=mock_generate()
        )

        response = await serving_embedding.handle_request(
            basic_embedding_request, mock_request
        )

        assert isinstance(response, EmbeddingResponse)
        assert len(response.data) == 1
        assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]

    async def test_handle_request_validation_error(
        self, serving_embedding, mock_request
    ):
        """Test handling request with validation error."""
        invalid_request = EmbeddingRequest(model="test-model", input="")

        response = await serving_embedding.handle_request(invalid_request, mock_request)

        assert isinstance(response, ORJSONResponse)
        assert response.status_code == 400

    async def test_handle_request_generation_error(
        self, serving_embedding, basic_embedding_request, mock_request
    ):
        """Test handling request with generation error."""
        # Mock generate_request to raise an error
        ...
            raise ValueError("Generation failed")
            yield  # This won't be reached but needed for async generator

        serving_embedding.tokenizer_manager.generate_request = Mock(
            return_value=mock_generate_error()
        )

        response = await serving_embedding.handle_request(
            basic_embedding_request, mock_request
        )

        assert isinstance(response, ORJSONResponse)
        assert response.status_code == 400

    async def test_handle_request_internal_error(
        self, serving_embedding, basic_embedding_request, mock_request
    ):
        """Test handling request with internal server error."""
        # Mock _convert_to_internal_request to raise an exception
        with patch.object(
            serving_embedding,
            "_convert_to_internal_request",
            side_effect=Exception("Internal error"),
        ):
            response = await serving_embedding.handle_request(
                basic_embedding_request, mock_request
            )

        assert isinstance(response, ORJSONResponse)
        assert response.status_code == 500


    def test_build_single_embedding_response(self):
        """Test building response for single embedding."""
        ret_data = [
            {
                ...
            }
        ]

        response = self.serving_embedding._build_embedding_response(
            ret_data, "test-model"
        )

        self.assertIsInstance(response, EmbeddingResponse)
        self.assertEqual(response.model, "test-model")
        self.assertEqual(len(response.data), 1)
        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
        self.assertEqual(response.data[0].index, 0)
        self.assertEqual(response.data[0].object, "embedding")
        self.assertEqual(response.usage.prompt_tokens, 5)
        self.assertEqual(response.usage.total_tokens, 5)
        self.assertEqual(response.usage.completion_tokens, 0)

    def test_build_multiple_embedding_response(self):
        """Test building response for multiple embeddings."""
        ret_data = [
            {
                ...
            },
        ]

        response = self.serving_embedding._build_embedding_response(
            ret_data, "test-model"
        )

        self.assertIsInstance(response, EmbeddingResponse)
        self.assertEqual(len(response.data), 2)
        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
        self.assertEqual(response.data[0].index, 0)
        self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
        self.assertEqual(response.data[1].index, 1)
        self.assertEqual(response.usage.prompt_tokens, 7)  # 3 + 4
        self.assertEqual(response.usage.total_tokens, 7)

    async def test_handle_request_success(self):
        """Test successful embedding request handling."""
        # Mock the generate_request to return expected data
        ...
            "meta_info": {"prompt_tokens": 5},
        }

        self.serving_embedding.tokenizer_manager.generate_request = Mock(
            return_value=mock_generate()
        )

        response = await self.serving_embedding.handle_request(
            self.basic_req, self.request
        )

        self.assertIsInstance(response, EmbeddingResponse)
        self.assertEqual(len(response.data), 1)
        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])

    async def test_handle_request_validation_error(self):
        """Test handling request with validation error."""
        invalid_request = EmbeddingRequest(model="test-model", input="")

        response = await self.serving_embedding.handle_request(
            invalid_request, self.request
        )

        self.assertIsInstance(response, ORJSONResponse)
        self.assertEqual(response.status_code, 400)

    async def test_handle_request_generation_error(self):
        """Test handling request with generation error."""
        # Mock generate_request to raise an error
        ...
            raise ValueError("Generation failed")
            yield  # This won't be reached but needed for async generator

        self.serving_embedding.tokenizer_manager.generate_request = Mock(
            return_value=mock_generate_error()
        )

        response = await self.serving_embedding.handle_request(
            self.basic_req, self.request
        )

        self.assertIsInstance(response, ORJSONResponse)
        self.assertEqual(response.status_code, 400)

    async def test_handle_request_internal_error(self):
        """Test handling request with internal server error."""
        # Mock _convert_to_internal_request to raise an exception
        with patch.object(
            self.serving_embedding,
            "_convert_to_internal_request",
            side_effect=Exception("Internal error"),
        ):
            response = await self.serving_embedding.handle_request(
                self.basic_req, self.request
            )

        self.assertIsInstance(response, ORJSONResponse)
        self.assertEqual(response.status_code, 500)


if __name__ == "__main__":
    unittest.main(verbosity=2)
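The three failure-path tests pin down the handler's status-code convention: validation failures and errors surfaced from generation map to HTTP 400, while unexpected exceptions inside the handler map to HTTP 500. A rough sketch of that pattern, where create_error_response, validate_request, and generate are assumed names for illustration, not the handler's real methods:

from fastapi.responses import ORJSONResponse


def create_error_response(message: str, status_code: int = 400) -> ORJSONResponse:
    # Hypothetical helper: wrap an error message in an OpenAI-style error body.
    return ORJSONResponse(
        content={"error": {"message": message, "code": status_code}},
        status_code=status_code,
    )


async def handle_request_sketch(handler, request, raw_request):
    # Sketch only: validate_request / generate are assumed hooks.
    try:
        error = handler.validate_request(request)
        if error:
            return create_error_response(error, 400)
        return await handler.generate(request, raw_request)
    except ValueError as e:
        # Expected failure surfaced from generation -> client error
        return create_error_response(str(e), 400)
    except Exception as e:
        # Anything unexpected inside the handler -> server error
        return create_error_response(str(e), 500)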
test/srt/run_suite.py  View file @ ffd1a26e
...

@@ -62,6 +62,11 @@ suites = {
        TestFile("test_openai_adapter.py", 1),
        TestFile("test_openai_function_calling.py", 60),
        TestFile("test_openai_server.py", 149),
        TestFile("openai/test_server.py", 120),
        TestFile("openai/test_protocol.py", 60),
        TestFile("openai/test_serving_chat.py", 120),
        TestFile("openai/test_serving_completions.py", 120),
        TestFile("openai/test_serving_embedding.py", 120),
        TestFile("test_openai_server_hidden_states.py", 240),
        TestFile("test_penalty.py", 41),
        TestFile("test_page_size.py", 60),

...
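Note on the hunk above: the five new openai/ entries register the refactored tests with the suite, and the second TestFile argument appears to be a per-file runtime estimate in seconds that the runner uses when budgeting CI shards. Assuming the runner's existing command-line interface, the new tests should then be runnable locally with `python3 run_suite.py --suite per-commit` from test/srt.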