Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ffd1a26e
Unverified
Commit
ffd1a26e
authored
Jun 18, 2025
by
Jinn
Committed by
GitHub
Jun 18, 2025
Browse files
Add more refactored openai test & in CI (#7284)
parent
09ae5b20
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
576 additions
and
1059 deletions
+576
-1059
python/sglang/srt/entrypoints/openai/api_server.py
python/sglang/srt/entrypoints/openai/api_server.py
+2
-2
test/srt/openai/conftest.py
test/srt/openai/conftest.py
+4
-3
test/srt/openai/test_protocol.py
test/srt/openai/test_protocol.py
+179
-177
test/srt/openai/test_server.py
test/srt/openai/test_server.py
+41
-5
test/srt/openai/test_serving_chat.py
test/srt/openai/test_serving_chat.py
+156
-562
test/srt/openai/test_serving_completions.py
test/srt/openai/test_serving_completions.py
+68
-143
test/srt/openai/test_serving_embedding.py
test/srt/openai/test_serving_embedding.py
+121
-167
test/srt/run_suite.py
test/srt/run_suite.py
+5
-0
No files found.
python/sglang/srt/entrypoints/openai/api_server.py
View file @
ffd1a26e
...
@@ -36,7 +36,7 @@ from fastapi.middleware.cors import CORSMiddleware
...
@@ -36,7 +36,7 @@ from fastapi.middleware.cors import CORSMiddleware
from
fastapi.responses
import
Response
from
fastapi.responses
import
Response
from
sglang.srt.disaggregation.utils
import
(
from
sglang.srt.disaggregation.utils
import
(
F
akeBootstrapHost
,
F
AKE_BOOTSTRAP_HOST
,
register_disaggregation_server
,
register_disaggregation_server
,
)
)
from
sglang.srt.entrypoints.engine
import
Engine
,
_launch_subprocesses
from
sglang.srt.entrypoints.engine
import
Engine
,
_launch_subprocesses
...
@@ -265,7 +265,7 @@ def _wait_and_warmup(
...
@@ -265,7 +265,7 @@ def _wait_and_warmup(
"max_new_tokens"
:
8
,
"max_new_tokens"
:
8
,
"ignore_eos"
:
True
,
"ignore_eos"
:
True
,
},
},
"bootstrap_host"
:
[
F
akeBootstrapHost
]
*
server_args
.
dp_size
,
"bootstrap_host"
:
[
F
AKE_BOOTSTRAP_HOST
]
*
server_args
.
dp_size
,
# This is a hack to ensure fake transfer is enabled during prefill warmup
# This is a hack to ensure fake transfer is enabled during prefill warmup
# ensure each dp rank has a unique bootstrap_room during prefill warmup
# ensure each dp rank has a unique bootstrap_room during prefill warmup
"bootstrap_room"
:
[
"bootstrap_room"
:
[
...
...
test/srt/openai/conftest.py
View file @
ffd1a26e
...
@@ -12,9 +12,10 @@ import pytest
...
@@ -12,9 +12,10 @@ import pytest
import
requests
import
requests
from
sglang.srt.utils
import
kill_process_tree
# reuse SGLang helper
from
sglang.srt.utils
import
kill_process_tree
# reuse SGLang helper
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
SERVER_MODULE
=
"sglang.srt.entrypoints.openai.api_server"
SERVER_MODULE
=
"sglang.srt.entrypoints.openai.api_server"
DEFAULT_MODEL
=
"dummy-model"
DEFAULT_MODEL
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
STARTUP_TIMEOUT
=
float
(
os
.
getenv
(
"SGLANG_OPENAI_STARTUP_TIMEOUT"
,
120
))
STARTUP_TIMEOUT
=
float
(
os
.
getenv
(
"SGLANG_OPENAI_STARTUP_TIMEOUT"
,
120
))
...
@@ -39,7 +40,7 @@ def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> No
...
@@ -39,7 +40,7 @@ def _wait_until_healthy(proc: subprocess.Popen, base: str, timeout: float) -> No
def
launch_openai_server
(
model
:
str
=
DEFAULT_MODEL
,
**
kw
):
def
launch_openai_server
(
model
:
str
=
DEFAULT_MODEL
,
**
kw
):
"""Spawn the draft OpenAI-compatible server and wait until it
’
s ready."""
"""Spawn the draft OpenAI-compatible server and wait until it
'
s ready."""
port
=
_pick_free_port
()
port
=
_pick_free_port
()
cmd
=
[
cmd
=
[
sys
.
executable
,
sys
.
executable
,
...
@@ -79,7 +80,7 @@ def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
...
@@ -79,7 +80,7 @@ def launch_openai_server(model: str = DEFAULT_MODEL, **kw):
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
openai_server
()
->
Generator
[
str
,
None
,
None
]:
def
openai_server
()
->
Generator
[
str
,
None
,
None
]:
"""PyTest fixture that provides the server
’
s base URL and cleans up."""
"""PyTest fixture that provides the server
'
s base URL and cleans up."""
proc
,
base
,
log_file
=
launch_openai_server
()
proc
,
base
,
log_file
=
launch_openai_server
()
yield
base
yield
base
kill_process_tree
(
proc
.
pid
)
kill_process_tree
(
proc
.
pid
)
...
...
test/srt/openai/test_protocol.py
View file @
ffd1a26e
...
@@ -15,9 +15,9 @@
...
@@ -15,9 +15,9 @@
import
json
import
json
import
time
import
time
import
unittest
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
import
pytest
from
pydantic
import
ValidationError
from
pydantic
import
ValidationError
from
sglang.srt.entrypoints.openai.protocol
import
(
from
sglang.srt.entrypoints.openai.protocol
import
(
...
@@ -64,18 +64,18 @@ from sglang.srt.entrypoints.openai.protocol import (
...
@@ -64,18 +64,18 @@ from sglang.srt.entrypoints.openai.protocol import (
)
)
class
TestModelCard
:
class
TestModelCard
(
unittest
.
TestCase
)
:
"""Test ModelCard protocol model"""
"""Test ModelCard protocol model"""
def
test_basic_model_card_creation
(
self
):
def
test_basic_model_card_creation
(
self
):
"""Test basic model card creation with required fields"""
"""Test basic model card creation with required fields"""
card
=
ModelCard
(
id
=
"test-model"
)
card
=
ModelCard
(
id
=
"test-model"
)
assert
card
.
id
==
"test-model"
self
.
assert
Equal
(
card
.
id
,
"test-model"
)
assert
card
.
object
==
"model"
self
.
assert
Equal
(
card
.
object
,
"model"
)
assert
card
.
owned_by
==
"sglang"
self
.
assert
Equal
(
card
.
owned_by
,
"sglang"
)
assert
isi
nstance
(
card
.
created
,
int
)
self
.
assert
IsI
nstance
(
card
.
created
,
int
)
assert
card
.
root
is
None
self
.
assert
IsNone
(
card
.
root
)
assert
card
.
max_model_len
is
None
self
.
assert
IsNone
(
card
.
max_model_len
)
def
test_model_card_with_optional_fields
(
self
):
def
test_model_card_with_optional_fields
(
self
):
"""Test model card with optional fields"""
"""Test model card with optional fields"""
...
@@ -85,28 +85,28 @@ class TestModelCard:
...
@@ -85,28 +85,28 @@ class TestModelCard:
max_model_len
=
2048
,
max_model_len
=
2048
,
created
=
1234567890
,
created
=
1234567890
,
)
)
assert
card
.
id
==
"test-model"
self
.
assert
Equal
(
card
.
id
,
"test-model"
)
assert
card
.
root
==
"/path/to/model"
self
.
assert
Equal
(
card
.
root
,
"/path/to/model"
)
assert
card
.
max_model_len
==
2048
self
.
assert
Equal
(
card
.
max_model_len
,
2048
)
assert
card
.
created
==
1234567890
self
.
assert
Equal
(
card
.
created
,
1234567890
)
def
test_model_card_serialization
(
self
):
def
test_model_card_serialization
(
self
):
"""Test model card JSON serialization"""
"""Test model card JSON serialization"""
card
=
ModelCard
(
id
=
"test-model"
,
max_model_len
=
4096
)
card
=
ModelCard
(
id
=
"test-model"
,
max_model_len
=
4096
)
data
=
card
.
model_dump
()
data
=
card
.
model_dump
()
assert
data
[
"id"
]
==
"test-model"
self
.
assert
Equal
(
data
[
"id"
]
,
"test-model"
)
assert
data
[
"object"
]
==
"model"
self
.
assert
Equal
(
data
[
"object"
]
,
"model"
)
assert
data
[
"max_model_len"
]
==
4096
self
.
assert
Equal
(
data
[
"max_model_len"
]
,
4096
)
class
TestModelList
:
class
TestModelList
(
unittest
.
TestCase
)
:
"""Test ModelList protocol model"""
"""Test ModelList protocol model"""
def
test_empty_model_list
(
self
):
def
test_empty_model_list
(
self
):
"""Test empty model list creation"""
"""Test empty model list creation"""
model_list
=
ModelList
()
model_list
=
ModelList
()
assert
model_list
.
object
==
"list"
self
.
assert
Equal
(
model_list
.
object
,
"list"
)
assert
len
(
model_list
.
data
)
==
0
self
.
assert
Equal
(
len
(
model_list
.
data
)
,
0
)
def
test_model_list_with_cards
(
self
):
def
test_model_list_with_cards
(
self
):
"""Test model list with model cards"""
"""Test model list with model cards"""
...
@@ -115,12 +115,12 @@ class TestModelList:
...
@@ -115,12 +115,12 @@ class TestModelList:
ModelCard
(
id
=
"model-2"
,
max_model_len
=
2048
),
ModelCard
(
id
=
"model-2"
,
max_model_len
=
2048
),
]
]
model_list
=
ModelList
(
data
=
cards
)
model_list
=
ModelList
(
data
=
cards
)
assert
len
(
model_list
.
data
)
==
2
self
.
assert
Equal
(
len
(
model_list
.
data
)
,
2
)
assert
model_list
.
data
[
0
].
id
==
"model-1"
self
.
assert
Equal
(
model_list
.
data
[
0
].
id
,
"model-1"
)
assert
model_list
.
data
[
1
].
id
==
"model-2"
self
.
assert
Equal
(
model_list
.
data
[
1
].
id
,
"model-2"
)
class
TestErrorResponse
:
class
TestErrorResponse
(
unittest
.
TestCase
)
:
"""Test ErrorResponse protocol model"""
"""Test ErrorResponse protocol model"""
def
test_basic_error_response
(
self
):
def
test_basic_error_response
(
self
):
...
@@ -128,11 +128,11 @@ class TestErrorResponse:
...
@@ -128,11 +128,11 @@ class TestErrorResponse:
error
=
ErrorResponse
(
error
=
ErrorResponse
(
message
=
"Invalid request"
,
type
=
"BadRequestError"
,
code
=
400
message
=
"Invalid request"
,
type
=
"BadRequestError"
,
code
=
400
)
)
assert
error
.
object
==
"error"
self
.
assert
Equal
(
error
.
object
,
"error"
)
assert
error
.
message
==
"Invalid request"
self
.
assert
Equal
(
error
.
message
,
"Invalid request"
)
assert
error
.
type
==
"BadRequestError"
self
.
assert
Equal
(
error
.
type
,
"BadRequestError"
)
assert
error
.
code
==
400
self
.
assert
Equal
(
error
.
code
,
400
)
assert
error
.
param
is
None
self
.
assert
IsNone
(
error
.
param
)
def
test_error_response_with_param
(
self
):
def
test_error_response_with_param
(
self
):
"""Test error response with parameter"""
"""Test error response with parameter"""
...
@@ -142,19 +142,19 @@ class TestErrorResponse:
...
@@ -142,19 +142,19 @@ class TestErrorResponse:
code
=
422
,
code
=
422
,
param
=
"temperature"
,
param
=
"temperature"
,
)
)
assert
error
.
param
==
"temperature"
self
.
assert
Equal
(
error
.
param
,
"temperature"
)
class
TestUsageInfo
:
class
TestUsageInfo
(
unittest
.
TestCase
)
:
"""Test UsageInfo protocol model"""
"""Test UsageInfo protocol model"""
def
test_basic_usage_info
(
self
):
def
test_basic_usage_info
(
self
):
"""Test basic usage info creation"""
"""Test basic usage info creation"""
usage
=
UsageInfo
(
prompt_tokens
=
10
,
completion_tokens
=
20
,
total_tokens
=
30
)
usage
=
UsageInfo
(
prompt_tokens
=
10
,
completion_tokens
=
20
,
total_tokens
=
30
)
assert
usage
.
prompt_tokens
==
10
self
.
assert
Equal
(
usage
.
prompt_tokens
,
10
)
assert
usage
.
completion_tokens
==
20
self
.
assert
Equal
(
usage
.
completion_tokens
,
20
)
assert
usage
.
total_tokens
==
30
self
.
assert
Equal
(
usage
.
total_tokens
,
30
)
assert
usage
.
prompt_tokens_details
is
None
self
.
assert
IsNone
(
usage
.
prompt_tokens_details
)
def
test_usage_info_with_cache_details
(
self
):
def
test_usage_info_with_cache_details
(
self
):
"""Test usage info with cache details"""
"""Test usage info with cache details"""
...
@@ -164,22 +164,22 @@ class TestUsageInfo:
...
@@ -164,22 +164,22 @@ class TestUsageInfo:
total_tokens
=
30
,
total_tokens
=
30
,
prompt_tokens_details
=
{
"cached_tokens"
:
5
},
prompt_tokens_details
=
{
"cached_tokens"
:
5
},
)
)
assert
usage
.
prompt_tokens_details
==
{
"cached_tokens"
:
5
}
self
.
assert
Equal
(
usage
.
prompt_tokens_details
,
{
"cached_tokens"
:
5
}
)
class
TestCompletionRequest
:
class
TestCompletionRequest
(
unittest
.
TestCase
)
:
"""Test CompletionRequest protocol model"""
"""Test CompletionRequest protocol model"""
def
test_basic_completion_request
(
self
):
def
test_basic_completion_request
(
self
):
"""Test basic completion request"""
"""Test basic completion request"""
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello world"
)
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello world"
)
assert
request
.
model
==
"test-model"
self
.
assert
Equal
(
request
.
model
,
"test-model"
)
assert
request
.
prompt
==
"Hello world"
self
.
assert
Equal
(
request
.
prompt
,
"Hello world"
)
assert
request
.
max_tokens
==
16
# default
self
.
assert
Equal
(
request
.
max_tokens
,
16
)
# default
assert
request
.
temperature
==
1.0
# default
self
.
assert
Equal
(
request
.
temperature
,
1.0
)
# default
assert
request
.
n
==
1
# default
self
.
assert
Equal
(
request
.
n
,
1
)
# default
assert
not
request
.
stream
# default
self
.
assert
False
(
request
.
stream
)
# default
assert
not
request
.
echo
# default
self
.
assert
False
(
request
.
echo
)
# default
def
test_completion_request_with_options
(
self
):
def
test_completion_request_with_options
(
self
):
"""Test completion request with various options"""
"""Test completion request with various options"""
...
@@ -195,15 +195,15 @@ class TestCompletionRequest:
...
@@ -195,15 +195,15 @@ class TestCompletionRequest:
stop
=
[
"."
,
"!"
],
stop
=
[
"."
,
"!"
],
logprobs
=
5
,
logprobs
=
5
,
)
)
assert
request
.
prompt
==
[
"Hello"
,
"world"
]
self
.
assert
Equal
(
request
.
prompt
,
[
"Hello"
,
"world"
]
)
assert
request
.
max_tokens
==
100
self
.
assert
Equal
(
request
.
max_tokens
,
100
)
assert
request
.
temperature
==
0.7
self
.
assert
Equal
(
request
.
temperature
,
0.7
)
assert
request
.
top_p
==
0.9
self
.
assert
Equal
(
request
.
top_p
,
0.9
)
assert
request
.
n
==
2
self
.
assert
Equal
(
request
.
n
,
2
)
assert
request
.
stream
self
.
assert
True
(
request
.
stream
)
assert
request
.
echo
self
.
assert
True
(
request
.
echo
)
assert
request
.
stop
==
[
"."
,
"!"
]
self
.
assert
Equal
(
request
.
stop
,
[
"."
,
"!"
]
)
assert
request
.
logprobs
==
5
self
.
assert
Equal
(
request
.
logprobs
,
5
)
def
test_completion_request_sglang_extensions
(
self
):
def
test_completion_request_sglang_extensions
(
self
):
"""Test completion request with SGLang-specific extensions"""
"""Test completion request with SGLang-specific extensions"""
...
@@ -217,23 +217,23 @@ class TestCompletionRequest:
...
@@ -217,23 +217,23 @@ class TestCompletionRequest:
json_schema
=
'{"type": "object"}'
,
json_schema
=
'{"type": "object"}'
,
lora_path
=
"/path/to/lora"
,
lora_path
=
"/path/to/lora"
,
)
)
assert
request
.
top_k
==
50
self
.
assert
Equal
(
request
.
top_k
,
50
)
assert
request
.
min_p
==
0.1
self
.
assert
Equal
(
request
.
min_p
,
0.1
)
assert
request
.
repetition_penalty
==
1.1
self
.
assert
Equal
(
request
.
repetition_penalty
,
1.1
)
assert
request
.
regex
==
r
"\d+"
self
.
assert
Equal
(
request
.
regex
,
r
"\d+"
)
assert
request
.
json_schema
==
'{"type": "object"}'
self
.
assert
Equal
(
request
.
json_schema
,
'{"type": "object"}'
)
assert
request
.
lora_path
==
"/path/to/lora"
self
.
assert
Equal
(
request
.
lora_path
,
"/path/to/lora"
)
def
test_completion_request_validation_errors
(
self
):
def
test_completion_request_validation_errors
(
self
):
"""Test completion request validation errors"""
"""Test completion request validation errors"""
with
pytest
.
r
aises
(
ValidationError
):
with
self
.
assertR
aises
(
ValidationError
):
CompletionRequest
()
# missing required fields
CompletionRequest
()
# missing required fields
with
pytest
.
r
aises
(
ValidationError
):
with
self
.
assertR
aises
(
ValidationError
):
CompletionRequest
(
model
=
"test-model"
)
# missing prompt
CompletionRequest
(
model
=
"test-model"
)
# missing prompt
class
TestCompletionResponse
:
class
TestCompletionResponse
(
unittest
.
TestCase
)
:
"""Test CompletionResponse protocol model"""
"""Test CompletionResponse protocol model"""
def
test_basic_completion_response
(
self
):
def
test_basic_completion_response
(
self
):
...
@@ -245,28 +245,28 @@ class TestCompletionResponse:
...
@@ -245,28 +245,28 @@ class TestCompletionResponse:
response
=
CompletionResponse
(
response
=
CompletionResponse
(
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
],
usage
=
usage
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
],
usage
=
usage
)
)
assert
response
.
id
==
"test-id"
self
.
assert
Equal
(
response
.
id
,
"test-id"
)
assert
response
.
object
==
"text_completion"
self
.
assert
Equal
(
response
.
object
,
"text_completion"
)
assert
response
.
model
==
"test-model"
self
.
assert
Equal
(
response
.
model
,
"test-model"
)
assert
len
(
response
.
choices
)
==
1
self
.
assert
Equal
(
len
(
response
.
choices
)
,
1
)
assert
response
.
choices
[
0
].
text
==
"Hello world!"
self
.
assert
Equal
(
response
.
choices
[
0
].
text
,
"Hello world!"
)
assert
response
.
usage
.
total_tokens
==
5
self
.
assert
Equal
(
response
.
usage
.
total_tokens
,
5
)
class
TestChatCompletionRequest
:
class
TestChatCompletionRequest
(
unittest
.
TestCase
)
:
"""Test ChatCompletionRequest protocol model"""
"""Test ChatCompletionRequest protocol model"""
def
test_basic_chat_completion_request
(
self
):
def
test_basic_chat_completion_request
(
self
):
"""Test basic chat completion request"""
"""Test basic chat completion request"""
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
)
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
)
assert
request
.
model
==
"test-model"
self
.
assert
Equal
(
request
.
model
,
"test-model"
)
assert
len
(
request
.
messages
)
==
1
self
.
assert
Equal
(
len
(
request
.
messages
)
,
1
)
assert
request
.
messages
[
0
].
role
==
"user"
self
.
assert
Equal
(
request
.
messages
[
0
].
role
,
"user"
)
assert
request
.
messages
[
0
].
content
==
"Hello"
self
.
assert
Equal
(
request
.
messages
[
0
].
content
,
"Hello"
)
assert
request
.
temperature
==
0.7
# default
self
.
assert
Equal
(
request
.
temperature
,
0.7
)
# default
assert
not
request
.
stream
# default
self
.
assert
False
(
request
.
stream
)
# default
assert
request
.
tool_choice
==
"none"
# default when no tools
self
.
assert
Equal
(
request
.
tool_choice
,
"none"
)
# default when no tools
def
test_chat_completion_with_multimodal_content
(
self
):
def
test_chat_completion_with_multimodal_content
(
self
):
"""Test chat completion with multimodal content"""
"""Test chat completion with multimodal content"""
...
@@ -283,9 +283,9 @@ class TestChatCompletionRequest:
...
@@ -283,9 +283,9 @@ class TestChatCompletionRequest:
}
}
]
]
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
)
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
)
assert
len
(
request
.
messages
[
0
].
content
)
==
2
self
.
assert
Equal
(
len
(
request
.
messages
[
0
].
content
)
,
2
)
assert
request
.
messages
[
0
].
content
[
0
].
type
==
"text"
self
.
assert
Equal
(
request
.
messages
[
0
].
content
[
0
].
type
,
"text"
)
assert
request
.
messages
[
0
].
content
[
1
].
type
==
"image_url"
self
.
assert
Equal
(
request
.
messages
[
0
].
content
[
1
].
type
,
"image_url"
)
def
test_chat_completion_with_tools
(
self
):
def
test_chat_completion_with_tools
(
self
):
"""Test chat completion with tools"""
"""Test chat completion with tools"""
...
@@ -306,9 +306,9 @@ class TestChatCompletionRequest:
...
@@ -306,9 +306,9 @@ class TestChatCompletionRequest:
request
=
ChatCompletionRequest
(
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
,
tools
=
tools
model
=
"test-model"
,
messages
=
messages
,
tools
=
tools
)
)
assert
len
(
request
.
tools
)
==
1
self
.
assert
Equal
(
len
(
request
.
tools
)
,
1
)
assert
request
.
tools
[
0
].
function
.
name
==
"get_weather"
self
.
assert
Equal
(
request
.
tools
[
0
].
function
.
name
,
"get_weather"
)
assert
request
.
tool_choice
==
"auto"
# default when tools present
self
.
assert
Equal
(
request
.
tool_choice
,
"auto"
)
# default when tools present
def
test_chat_completion_tool_choice_validation
(
self
):
def
test_chat_completion_tool_choice_validation
(
self
):
"""Test tool choice validation logic"""
"""Test tool choice validation logic"""
...
@@ -316,7 +316,7 @@ class TestChatCompletionRequest:
...
@@ -316,7 +316,7 @@ class TestChatCompletionRequest:
# No tools, tool_choice should default to "none"
# No tools, tool_choice should default to "none"
request1
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
)
request1
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
)
assert
request1
.
tool_choice
==
"none"
self
.
assert
Equal
(
request1
.
tool_choice
,
"none"
)
# With tools, tool_choice should default to "auto"
# With tools, tool_choice should default to "auto"
tools
=
[
tools
=
[
...
@@ -328,7 +328,7 @@ class TestChatCompletionRequest:
...
@@ -328,7 +328,7 @@ class TestChatCompletionRequest:
request2
=
ChatCompletionRequest
(
request2
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
,
tools
=
tools
model
=
"test-model"
,
messages
=
messages
,
tools
=
tools
)
)
assert
request2
.
tool_choice
==
"auto"
self
.
assert
Equal
(
request2
.
tool_choice
,
"auto"
)
def
test_chat_completion_sglang_extensions
(
self
):
def
test_chat_completion_sglang_extensions
(
self
):
"""Test chat completion with SGLang extensions"""
"""Test chat completion with SGLang extensions"""
...
@@ -342,14 +342,14 @@ class TestChatCompletionRequest:
...
@@ -342,14 +342,14 @@ class TestChatCompletionRequest:
stream_reasoning
=
False
,
stream_reasoning
=
False
,
chat_template_kwargs
=
{
"custom_param"
:
"value"
},
chat_template_kwargs
=
{
"custom_param"
:
"value"
},
)
)
assert
request
.
top_k
==
40
self
.
assert
Equal
(
request
.
top_k
,
40
)
assert
request
.
min_p
==
0.05
self
.
assert
Equal
(
request
.
min_p
,
0.05
)
assert
not
request
.
separate_reasoning
self
.
assert
False
(
request
.
separate_reasoning
)
assert
not
request
.
stream_reasoning
self
.
assert
False
(
request
.
stream_reasoning
)
assert
request
.
chat_template_kwargs
==
{
"custom_param"
:
"value"
}
self
.
assert
Equal
(
request
.
chat_template_kwargs
,
{
"custom_param"
:
"value"
}
)
class
TestChatCompletionResponse
:
class
TestChatCompletionResponse
(
unittest
.
TestCase
)
:
"""Test ChatCompletionResponse protocol model"""
"""Test ChatCompletionResponse protocol model"""
def
test_basic_chat_completion_response
(
self
):
def
test_basic_chat_completion_response
(
self
):
...
@@ -362,11 +362,11 @@ class TestChatCompletionResponse:
...
@@ -362,11 +362,11 @@ class TestChatCompletionResponse:
response
=
ChatCompletionResponse
(
response
=
ChatCompletionResponse
(
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
],
usage
=
usage
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
],
usage
=
usage
)
)
assert
response
.
id
==
"test-id"
self
.
assert
Equal
(
response
.
id
,
"test-id"
)
assert
response
.
object
==
"chat.completion"
self
.
assert
Equal
(
response
.
object
,
"chat.completion"
)
assert
response
.
model
==
"test-model"
self
.
assert
Equal
(
response
.
model
,
"test-model"
)
assert
len
(
response
.
choices
)
==
1
self
.
assert
Equal
(
len
(
response
.
choices
)
,
1
)
assert
response
.
choices
[
0
].
message
.
content
==
"Hello there!"
self
.
assert
Equal
(
response
.
choices
[
0
].
message
.
content
,
"Hello there!"
)
def
test_chat_completion_response_with_tool_calls
(
self
):
def
test_chat_completion_response_with_tool_calls
(
self
):
"""Test chat completion response with tool calls"""
"""Test chat completion response with tool calls"""
...
@@ -384,28 +384,30 @@ class TestChatCompletionResponse:
...
@@ -384,28 +384,30 @@ class TestChatCompletionResponse:
response
=
ChatCompletionResponse
(
response
=
ChatCompletionResponse
(
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
],
usage
=
usage
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
],
usage
=
usage
)
)
assert
response
.
choices
[
0
].
message
.
tool_calls
[
0
].
function
.
name
==
"get_weather"
self
.
assertEqual
(
assert
response
.
choices
[
0
].
finish_reason
==
"tool_calls"
response
.
choices
[
0
].
message
.
tool_calls
[
0
].
function
.
name
,
"get_weather"
)
self
.
assertEqual
(
response
.
choices
[
0
].
finish_reason
,
"tool_calls"
)
class
TestEmbeddingRequest
:
class
TestEmbeddingRequest
(
unittest
.
TestCase
)
:
"""Test EmbeddingRequest protocol model"""
"""Test EmbeddingRequest protocol model"""
def
test_basic_embedding_request
(
self
):
def
test_basic_embedding_request
(
self
):
"""Test basic embedding request"""
"""Test basic embedding request"""
request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
"Hello world"
)
request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
"Hello world"
)
assert
request
.
model
==
"test-model"
self
.
assert
Equal
(
request
.
model
,
"test-model"
)
assert
request
.
input
==
"Hello world"
self
.
assert
Equal
(
request
.
input
,
"Hello world"
)
assert
request
.
encoding_format
==
"float"
# default
self
.
assert
Equal
(
request
.
encoding_format
,
"float"
)
# default
assert
request
.
dimensions
is
None
# default
self
.
assert
IsNone
(
request
.
dimensions
)
# default
def
test_embedding_request_with_list_input
(
self
):
def
test_embedding_request_with_list_input
(
self
):
"""Test embedding request with list input"""
"""Test embedding request with list input"""
request
=
EmbeddingRequest
(
request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
[
"Hello"
,
"world"
],
dimensions
=
512
model
=
"test-model"
,
input
=
[
"Hello"
,
"world"
],
dimensions
=
512
)
)
assert
request
.
input
==
[
"Hello"
,
"world"
]
self
.
assert
Equal
(
request
.
input
,
[
"Hello"
,
"world"
]
)
assert
request
.
dimensions
==
512
self
.
assert
Equal
(
request
.
dimensions
,
512
)
def
test_multimodal_embedding_request
(
self
):
def
test_multimodal_embedding_request
(
self
):
"""Test multimodal embedding request"""
"""Test multimodal embedding request"""
...
@@ -414,14 +416,14 @@ class TestEmbeddingRequest:
...
@@ -414,14 +416,14 @@ class TestEmbeddingRequest:
MultimodalEmbeddingInput
(
text
=
"World"
,
image
=
None
),
MultimodalEmbeddingInput
(
text
=
"World"
,
image
=
None
),
]
]
request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
multimodal_input
)
request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
multimodal_input
)
assert
len
(
request
.
input
)
==
2
self
.
assert
Equal
(
len
(
request
.
input
)
,
2
)
assert
request
.
input
[
0
].
text
==
"Hello"
self
.
assert
Equal
(
request
.
input
[
0
].
text
,
"Hello"
)
assert
request
.
input
[
0
].
image
==
"base64_image_data"
self
.
assert
Equal
(
request
.
input
[
0
].
image
,
"base64_image_data"
)
assert
request
.
input
[
1
].
text
==
"World"
self
.
assert
Equal
(
request
.
input
[
1
].
text
,
"World"
)
assert
request
.
input
[
1
].
image
is
None
self
.
assert
IsNone
(
request
.
input
[
1
].
image
)
class
TestEmbeddingResponse
:
class
TestEmbeddingResponse
(
unittest
.
TestCase
)
:
"""Test EmbeddingResponse protocol model"""
"""Test EmbeddingResponse protocol model"""
def
test_basic_embedding_response
(
self
):
def
test_basic_embedding_response
(
self
):
...
@@ -431,14 +433,14 @@ class TestEmbeddingResponse:
...
@@ -431,14 +433,14 @@ class TestEmbeddingResponse:
response
=
EmbeddingResponse
(
response
=
EmbeddingResponse
(
data
=
[
embedding_obj
],
model
=
"test-model"
,
usage
=
usage
data
=
[
embedding_obj
],
model
=
"test-model"
,
usage
=
usage
)
)
assert
response
.
object
==
"list"
self
.
assert
Equal
(
response
.
object
,
"list"
)
assert
len
(
response
.
data
)
==
1
self
.
assert
Equal
(
len
(
response
.
data
)
,
1
)
assert
response
.
data
[
0
].
embedding
==
[
0.1
,
0.2
,
0.3
]
self
.
assert
Equal
(
response
.
data
[
0
].
embedding
,
[
0.1
,
0.2
,
0.3
]
)
assert
response
.
data
[
0
].
index
==
0
self
.
assert
Equal
(
response
.
data
[
0
].
index
,
0
)
assert
response
.
usage
.
prompt_tokens
==
3
self
.
assert
Equal
(
response
.
usage
.
prompt_tokens
,
3
)
class
TestScoringRequest
:
class
TestScoringRequest
(
unittest
.
TestCase
)
:
"""Test ScoringRequest protocol model"""
"""Test ScoringRequest protocol model"""
def
test_basic_scoring_request
(
self
):
def
test_basic_scoring_request
(
self
):
...
@@ -446,11 +448,11 @@ class TestScoringRequest:
...
@@ -446,11 +448,11 @@ class TestScoringRequest:
request
=
ScoringRequest
(
request
=
ScoringRequest
(
model
=
"test-model"
,
query
=
"Hello"
,
items
=
[
"World"
,
"Earth"
]
model
=
"test-model"
,
query
=
"Hello"
,
items
=
[
"World"
,
"Earth"
]
)
)
assert
request
.
model
==
"test-model"
self
.
assert
Equal
(
request
.
model
,
"test-model"
)
assert
request
.
query
==
"Hello"
self
.
assert
Equal
(
request
.
query
,
"Hello"
)
assert
request
.
items
==
[
"World"
,
"Earth"
]
self
.
assert
Equal
(
request
.
items
,
[
"World"
,
"Earth"
]
)
assert
not
request
.
apply_softmax
# default
self
.
assert
False
(
request
.
apply_softmax
)
# default
assert
not
request
.
item_first
# default
self
.
assert
False
(
request
.
item_first
)
# default
def
test_scoring_request_with_token_ids
(
self
):
def
test_scoring_request_with_token_ids
(
self
):
"""Test scoring request with token IDs"""
"""Test scoring request with token IDs"""
...
@@ -462,34 +464,34 @@ class TestScoringRequest:
...
@@ -462,34 +464,34 @@ class TestScoringRequest:
apply_softmax
=
True
,
apply_softmax
=
True
,
item_first
=
True
,
item_first
=
True
,
)
)
assert
request
.
query
==
[
1
,
2
,
3
]
self
.
assert
Equal
(
request
.
query
,
[
1
,
2
,
3
]
)
assert
request
.
items
==
[[
4
,
5
],
[
6
,
7
]]
self
.
assert
Equal
(
request
.
items
,
[[
4
,
5
],
[
6
,
7
]]
)
assert
request
.
label_token_ids
==
[
8
,
9
]
self
.
assert
Equal
(
request
.
label_token_ids
,
[
8
,
9
]
)
assert
request
.
apply_softmax
self
.
assert
True
(
request
.
apply_softmax
)
assert
request
.
item_first
self
.
assert
True
(
request
.
item_first
)
class
TestScoringResponse
:
class
TestScoringResponse
(
unittest
.
TestCase
)
:
"""Test ScoringResponse protocol model"""
"""Test ScoringResponse protocol model"""
def
test_basic_scoring_response
(
self
):
def
test_basic_scoring_response
(
self
):
"""Test basic scoring response"""
"""Test basic scoring response"""
response
=
ScoringResponse
(
scores
=
[[
0.1
,
0.9
],
[
0.3
,
0.7
]],
model
=
"test-model"
)
response
=
ScoringResponse
(
scores
=
[[
0.1
,
0.9
],
[
0.3
,
0.7
]],
model
=
"test-model"
)
assert
response
.
object
==
"scoring"
self
.
assert
Equal
(
response
.
object
,
"scoring"
)
assert
response
.
scores
==
[[
0.1
,
0.9
],
[
0.3
,
0.7
]]
self
.
assert
Equal
(
response
.
scores
,
[[
0.1
,
0.9
],
[
0.3
,
0.7
]]
)
assert
response
.
model
==
"test-model"
self
.
assert
Equal
(
response
.
model
,
"test-model"
)
assert
response
.
usage
is
None
# default
self
.
assert
IsNone
(
response
.
usage
)
# default
class
TestFileOperations
:
class
TestFileOperations
(
unittest
.
TestCase
)
:
"""Test file operation protocol models"""
"""Test file operation protocol models"""
def
test_file_request
(
self
):
def
test_file_request
(
self
):
"""Test file request model"""
"""Test file request model"""
file_data
=
b
"test file content"
file_data
=
b
"test file content"
request
=
FileRequest
(
file
=
file_data
,
purpose
=
"batch"
)
request
=
FileRequest
(
file
=
file_data
,
purpose
=
"batch"
)
assert
request
.
file
==
file_data
self
.
assert
Equal
(
request
.
file
,
file_data
)
assert
request
.
purpose
==
"batch"
self
.
assert
Equal
(
request
.
purpose
,
"batch"
)
def
test_file_response
(
self
):
def
test_file_response
(
self
):
"""Test file response model"""
"""Test file response model"""
...
@@ -500,20 +502,20 @@ class TestFileOperations:
...
@@ -500,20 +502,20 @@ class TestFileOperations:
filename
=
"test.jsonl"
,
filename
=
"test.jsonl"
,
purpose
=
"batch"
,
purpose
=
"batch"
,
)
)
assert
response
.
id
==
"file-123"
self
.
assert
Equal
(
response
.
id
,
"file-123"
)
assert
response
.
object
==
"file"
self
.
assert
Equal
(
response
.
object
,
"file"
)
assert
response
.
bytes
==
1024
self
.
assert
Equal
(
response
.
bytes
,
1024
)
assert
response
.
filename
==
"test.jsonl"
self
.
assert
Equal
(
response
.
filename
,
"test.jsonl"
)
def
test_file_delete_response
(
self
):
def
test_file_delete_response
(
self
):
"""Test file delete response model"""
"""Test file delete response model"""
response
=
FileDeleteResponse
(
id
=
"file-123"
,
deleted
=
True
)
response
=
FileDeleteResponse
(
id
=
"file-123"
,
deleted
=
True
)
assert
response
.
id
==
"file-123"
self
.
assert
Equal
(
response
.
id
,
"file-123"
)
assert
response
.
object
==
"file"
self
.
assert
Equal
(
response
.
object
,
"file"
)
assert
response
.
deleted
self
.
assert
True
(
response
.
deleted
)
class
TestBatchOperations
:
class
TestBatchOperations
(
unittest
.
TestCase
)
:
"""Test batch operation protocol models"""
"""Test batch operation protocol models"""
def
test_batch_request
(
self
):
def
test_batch_request
(
self
):
...
@@ -524,10 +526,10 @@ class TestBatchOperations:
...
@@ -524,10 +526,10 @@ class TestBatchOperations:
completion_window
=
"24h"
,
completion_window
=
"24h"
,
metadata
=
{
"custom"
:
"value"
},
metadata
=
{
"custom"
:
"value"
},
)
)
assert
request
.
input_file_id
==
"file-123"
self
.
assert
Equal
(
request
.
input_file_id
,
"file-123"
)
assert
request
.
endpoint
==
"/v1/chat/completions"
self
.
assert
Equal
(
request
.
endpoint
,
"/v1/chat/completions"
)
assert
request
.
completion_window
==
"24h"
self
.
assert
Equal
(
request
.
completion_window
,
"24h"
)
assert
request
.
metadata
==
{
"custom"
:
"value"
}
self
.
assert
Equal
(
request
.
metadata
,
{
"custom"
:
"value"
}
)
def
test_batch_response
(
self
):
def
test_batch_response
(
self
):
"""Test batch response model"""
"""Test batch response model"""
...
@@ -538,20 +540,20 @@ class TestBatchOperations:
...
@@ -538,20 +540,20 @@ class TestBatchOperations:
completion_window
=
"24h"
,
completion_window
=
"24h"
,
created_at
=
1234567890
,
created_at
=
1234567890
,
)
)
assert
response
.
id
==
"batch-123"
self
.
assert
Equal
(
response
.
id
,
"batch-123"
)
assert
response
.
object
==
"batch"
self
.
assert
Equal
(
response
.
object
,
"batch"
)
assert
response
.
status
==
"validating"
# default
self
.
assert
Equal
(
response
.
status
,
"validating"
)
# default
assert
response
.
endpoint
==
"/v1/chat/completions"
self
.
assert
Equal
(
response
.
endpoint
,
"/v1/chat/completions"
)
class
TestResponseFormats
:
class
TestResponseFormats
(
unittest
.
TestCase
)
:
"""Test response format protocol models"""
"""Test response format protocol models"""
def
test_basic_response_format
(
self
):
def
test_basic_response_format
(
self
):
"""Test basic response format"""
"""Test basic response format"""
format_obj
=
ResponseFormat
(
type
=
"json_object"
)
format_obj
=
ResponseFormat
(
type
=
"json_object"
)
assert
format_obj
.
type
==
"json_object"
self
.
assert
Equal
(
format_obj
.
type
,
"json_object"
)
assert
format_obj
.
json_schema
is
None
self
.
assert
IsNone
(
format_obj
.
json_schema
)
def
test_json_schema_response_format
(
self
):
def
test_json_schema_response_format
(
self
):
"""Test JSON schema response format"""
"""Test JSON schema response format"""
...
@@ -560,9 +562,9 @@ class TestResponseFormats:
...
@@ -560,9 +562,9 @@ class TestResponseFormats:
name
=
"person_schema"
,
description
=
"Person schema"
,
schema
=
schema
name
=
"person_schema"
,
description
=
"Person schema"
,
schema
=
schema
)
)
format_obj
=
ResponseFormat
(
type
=
"json_schema"
,
json_schema
=
json_schema
)
format_obj
=
ResponseFormat
(
type
=
"json_schema"
,
json_schema
=
json_schema
)
assert
format_obj
.
type
==
"json_schema"
self
.
assert
Equal
(
format_obj
.
type
,
"json_schema"
)
assert
format_obj
.
json_schema
.
name
==
"person_schema"
self
.
assert
Equal
(
format_obj
.
json_schema
.
name
,
"person_schema"
)
assert
format_obj
.
json_schema
.
schema_
==
schema
self
.
assert
Equal
(
format_obj
.
json_schema
.
schema_
,
schema
)
def
test_structural_tag_response_format
(
self
):
def
test_structural_tag_response_format
(
self
):
"""Test structural tag response format"""
"""Test structural tag response format"""
...
@@ -576,12 +578,12 @@ class TestResponseFormats:
...
@@ -576,12 +578,12 @@ class TestResponseFormats:
format_obj
=
StructuralTagResponseFormat
(
format_obj
=
StructuralTagResponseFormat
(
type
=
"structural_tag"
,
structures
=
structures
,
triggers
=
[
"think"
]
type
=
"structural_tag"
,
structures
=
structures
,
triggers
=
[
"think"
]
)
)
assert
format_obj
.
type
==
"structural_tag"
self
.
assert
Equal
(
format_obj
.
type
,
"structural_tag"
)
assert
len
(
format_obj
.
structures
)
==
1
self
.
assert
Equal
(
len
(
format_obj
.
structures
)
,
1
)
assert
format_obj
.
triggers
==
[
"think"
]
self
.
assert
Equal
(
format_obj
.
triggers
,
[
"think"
]
)
class
TestLogProbs
:
class
TestLogProbs
(
unittest
.
TestCase
)
:
"""Test LogProbs protocol models"""
"""Test LogProbs protocol models"""
def
test_basic_logprobs
(
self
):
def
test_basic_logprobs
(
self
):
...
@@ -592,9 +594,9 @@ class TestLogProbs:
...
@@ -592,9 +594,9 @@ class TestLogProbs:
tokens
=
[
"Hello"
,
" "
,
"world"
],
tokens
=
[
"Hello"
,
" "
,
"world"
],
top_logprobs
=
[{
"Hello"
:
-
0.1
},
{
" "
:
-
0.2
},
{
"world"
:
-
0.3
}],
top_logprobs
=
[{
"Hello"
:
-
0.1
},
{
" "
:
-
0.2
},
{
"world"
:
-
0.3
}],
)
)
assert
len
(
logprobs
.
tokens
)
==
3
self
.
assert
Equal
(
len
(
logprobs
.
tokens
)
,
3
)
assert
logprobs
.
tokens
==
[
"Hello"
,
" "
,
"world"
]
self
.
assert
Equal
(
logprobs
.
tokens
,
[
"Hello"
,
" "
,
"world"
]
)
assert
logprobs
.
token_logprobs
==
[
-
0.1
,
-
0.2
,
-
0.3
]
self
.
assert
Equal
(
logprobs
.
token_logprobs
,
[
-
0.1
,
-
0.2
,
-
0.3
]
)
def
test_choice_logprobs
(
self
):
def
test_choice_logprobs
(
self
):
"""Test ChoiceLogprobs model"""
"""Test ChoiceLogprobs model"""
...
@@ -607,17 +609,17 @@ class TestLogProbs:
...
@@ -607,17 +609,17 @@ class TestLogProbs:
],
],
)
)
choice_logprobs
=
ChoiceLogprobs
(
content
=
[
token_logprob
])
choice_logprobs
=
ChoiceLogprobs
(
content
=
[
token_logprob
])
assert
len
(
choice_logprobs
.
content
)
==
1
self
.
assert
Equal
(
len
(
choice_logprobs
.
content
)
,
1
)
assert
choice_logprobs
.
content
[
0
].
token
==
"Hello"
self
.
assert
Equal
(
choice_logprobs
.
content
[
0
].
token
,
"Hello"
)
class
TestStreamingModels
:
class
TestStreamingModels
(
unittest
.
TestCase
)
:
"""Test streaming response models"""
"""Test streaming response models"""
def
test_stream_options
(
self
):
def
test_stream_options
(
self
):
"""Test StreamOptions model"""
"""Test StreamOptions model"""
options
=
StreamOptions
(
include_usage
=
True
)
options
=
StreamOptions
(
include_usage
=
True
)
assert
options
.
include_usage
self
.
assert
True
(
options
.
include_usage
)
def
test_chat_completion_stream_response
(
self
):
def
test_chat_completion_stream_response
(
self
):
"""Test ChatCompletionStreamResponse model"""
"""Test ChatCompletionStreamResponse model"""
...
@@ -626,29 +628,29 @@ class TestStreamingModels:
...
@@ -626,29 +628,29 @@ class TestStreamingModels:
response
=
ChatCompletionStreamResponse
(
response
=
ChatCompletionStreamResponse
(
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
]
id
=
"test-id"
,
model
=
"test-model"
,
choices
=
[
choice
]
)
)
assert
response
.
object
==
"chat.completion.chunk"
self
.
assert
Equal
(
response
.
object
,
"chat.completion.chunk"
)
assert
response
.
choices
[
0
].
delta
.
content
==
"Hello"
self
.
assert
Equal
(
response
.
choices
[
0
].
delta
.
content
,
"Hello"
)
class
TestValidationEdgeCases
:
class
TestValidationEdgeCases
(
unittest
.
TestCase
)
:
"""Test edge cases and validation scenarios"""
"""Test edge cases and validation scenarios"""
def
test_empty_messages_validation
(
self
):
def
test_empty_messages_validation
(
self
):
"""Test validation with empty messages"""
"""Test validation with empty messages"""
with
pytest
.
r
aises
(
ValidationError
):
with
self
.
assertR
aises
(
ValidationError
):
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[])
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[])
def
test_invalid_tool_choice_type
(
self
):
def
test_invalid_tool_choice_type
(
self
):
"""Test invalid tool choice type"""
"""Test invalid tool choice type"""
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
with
pytest
.
r
aises
(
ValidationError
):
with
self
.
assertR
aises
(
ValidationError
):
ChatCompletionRequest
(
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
messages
,
tool_choice
=
123
model
=
"test-model"
,
messages
=
messages
,
tool_choice
=
123
)
)
def
test_negative_token_limits
(
self
):
def
test_negative_token_limits
(
self
):
"""Test negative token limits"""
"""Test negative token limits"""
with
pytest
.
r
aises
(
ValidationError
):
with
self
.
assertR
aises
(
ValidationError
):
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello"
,
max_tokens
=-
1
)
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello"
,
max_tokens
=-
1
)
def
test_invalid_temperature_range
(
self
):
def
test_invalid_temperature_range
(
self
):
...
@@ -656,7 +658,7 @@ class TestValidationEdgeCases:
...
@@ -656,7 +658,7 @@ class TestValidationEdgeCases:
# Note: The current protocol doesn't enforce temperature range,
# Note: The current protocol doesn't enforce temperature range,
# but this test documents expected behavior
# but this test documents expected behavior
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello"
,
temperature
=
5.0
)
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello"
,
temperature
=
5.0
)
assert
request
.
temperature
==
5.0
# Currently allowed
self
.
assert
Equal
(
request
.
temperature
,
5.0
)
# Currently allowed
def
test_model_serialization_roundtrip
(
self
):
def
test_model_serialization_roundtrip
(
self
):
"""Test that models can be serialized and deserialized"""
"""Test that models can be serialized and deserialized"""
...
@@ -673,11 +675,11 @@ class TestValidationEdgeCases:
...
@@ -673,11 +675,11 @@ class TestValidationEdgeCases:
# Deserialize back
# Deserialize back
restored_request
=
ChatCompletionRequest
(
**
data
)
restored_request
=
ChatCompletionRequest
(
**
data
)
assert
restored_request
.
model
==
original_request
.
model
self
.
assert
Equal
(
restored_request
.
model
,
original_request
.
model
)
assert
restored_request
.
temperature
==
original_request
.
temperature
self
.
assert
Equal
(
restored_request
.
temperature
,
original_request
.
temperature
)
assert
restored_request
.
max_tokens
==
original_request
.
max_tokens
self
.
assert
Equal
(
restored_request
.
max_tokens
,
original_request
.
max_tokens
)
assert
len
(
restored_request
.
messages
)
==
len
(
original_request
.
messages
)
self
.
assert
Equal
(
len
(
restored_request
.
messages
)
,
len
(
original_request
.
messages
)
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
py
test
.
main
(
[
__file__
]
)
unit
test
.
main
(
verbosity
=
2
)
test/srt/openai/test_server.py
View file @
ffd1a26e
# sglang/test/srt/openai/test_server.py
# sglang/test/srt/openai/test_server.py
import
pytest
import
requests
import
requests
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
as
MODEL_ID
def
test_health
(
openai_server
:
str
):
def
test_health
(
openai_server
:
str
):
r
=
requests
.
get
(
f
"
{
openai_server
}
/health"
)
r
=
requests
.
get
(
f
"
{
openai_server
}
/health"
)
assert
r
.
status_code
==
200
,
r
.
text
assert
r
.
status_code
==
200
# FastAPI returns an empty body → r.text == ""
assert
r
.
text
==
""
assert
r
.
text
==
""
@
pytest
.
mark
.
xfail
(
reason
=
"Endpoint skeleton not implemented yet"
)
def
test_models_endpoint
(
openai_server
:
str
):
def
test_models_endpoint
(
openai_server
:
str
):
r
=
requests
.
get
(
f
"
{
openai_server
}
/v1/models"
)
r
=
requests
.
get
(
f
"
{
openai_server
}
/v1/models"
)
# once implemented this should be 200
assert
r
.
status_code
==
200
,
r
.
text
assert
r
.
status_code
==
200
payload
=
r
.
json
()
# Basic contract
assert
"data"
in
payload
and
isinstance
(
payload
[
"data"
],
list
)
and
payload
[
"data"
]
# Validate fields of the first model card
first
=
payload
[
"data"
][
0
]
for
key
in
(
"id"
,
"root"
,
"max_model_len"
):
assert
key
in
first
,
f
"missing
{
key
}
in
{
first
}
"
# max_model_len must be positive
assert
isinstance
(
first
[
"max_model_len"
],
int
)
and
first
[
"max_model_len"
]
>
0
# The server should report the same model id we launched it with
ids
=
{
m
[
"id"
]
for
m
in
payload
[
"data"
]}
assert
MODEL_ID
in
ids
def
test_get_model_info
(
openai_server
:
str
):
r
=
requests
.
get
(
f
"
{
openai_server
}
/get_model_info"
)
assert
r
.
status_code
==
200
,
r
.
text
info
=
r
.
json
()
expected_keys
=
{
"model_path"
,
"tokenizer_path"
,
"is_generation"
}
assert
expected_keys
.
issubset
(
info
.
keys
())
# model_path must end with the one we passed on the CLI
assert
info
[
"model_path"
].
endswith
(
MODEL_ID
)
# is_generation is documented as a boolean
assert
isinstance
(
info
[
"is_generation"
],
bool
)
def
test_unknown_route_returns_404
(
openai_server
:
str
):
r
=
requests
.
get
(
f
"
{
openai_server
}
/definitely-not-a-real-route"
)
assert
r
.
status_code
==
404
test/srt/openai/test_serving_chat.py
View file @
ffd1a26e
"""
"""
Unit tests for the OpenAIServingChat class from serving_chat.py.
Unit-tests for OpenAIServingChat — rewritten to use only the std-lib 'unittest'.
Run with either:
These tests ensure that the refactored implementation maintains compatibility
python tests/test_serving_chat_unit.py -v
with the original adapter.py functionality.
or
python -m unittest discover -s tests -p "test_*unit.py" -v
"""
"""
import
unittest
import
uuid
import
uuid
from
typing
import
Optional
from
unittest.mock
import
Mock
,
patch
from
unittest.mock
import
Mock
,
patch
import
pytest
from
fastapi
import
Request
from
fastapi
import
Request
from
sglang.srt.entrypoints.openai.protocol
import
ChatCompletionRequest
,
ErrorResponse
from
sglang.srt.entrypoints.openai.protocol
import
ChatCompletionRequest
from
sglang.srt.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
sglang.srt.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
sglang.srt.managers.io_struct
import
GenerateReqInput
from
sglang.srt.managers.io_struct
import
GenerateReqInput
# Mock TokenizerManager since it may not be directly importable in tests
class
_MockTokenizerManager
:
class
MockTokenizerManager
:
"""Minimal mock that satisfies OpenAIServingChat."""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
model_config
=
Mock
()
self
.
model_config
=
Mock
(
is_multimodal
=
False
)
self
.
model_config
.
is_multimodal
=
False
self
.
server_args
=
Mock
(
self
.
server_args
=
Mock
()
enable_cache_report
=
False
,
self
.
server_args
.
enable_cache_report
=
False
tool_call_parser
=
"hermes"
,
self
.
server_args
.
tool_call_parser
=
"hermes"
reasoning_parser
=
None
,
self
.
server_args
.
reasoning_parser
=
None
)
self
.
chat_template_name
=
"llama-3"
self
.
chat_template_name
:
Optional
[
str
]
=
"llama-3"
#
Mock
tokenizer
# tokenizer
stub
self
.
tokenizer
=
Mock
()
self
.
tokenizer
=
Mock
()
self
.
tokenizer
.
encode
=
Mock
(
return_value
=
[
1
,
2
,
3
,
4
,
5
]
)
self
.
tokenizer
.
encode
.
return_value
=
[
1
,
2
,
3
,
4
,
5
]
self
.
tokenizer
.
decode
=
Mock
(
return_value
=
"Test response"
)
self
.
tokenizer
.
decode
.
return_value
=
"Test response"
self
.
tokenizer
.
chat_template
=
None
self
.
tokenizer
.
chat_template
=
None
self
.
tokenizer
.
bos_token_id
=
1
self
.
tokenizer
.
bos_token_id
=
1
#
Mock
generate_request
method
#
async generator stub for
generate_request
async
def
mock_generate
():
async
def
_
mock_generate
():
yield
{
yield
{
"text"
:
"Test response"
,
"text"
:
"Test response"
,
"meta_info"
:
{
"meta_info"
:
{
...
@@ -50,585 +53,176 @@ class MockTokenizerManager:
...
@@ -50,585 +53,176 @@ class MockTokenizerManager:
"index"
:
0
,
"index"
:
0
,
}
}
self
.
generate_request
=
Mock
(
return_value
=
mock_generate
())
self
.
generate_request
=
Mock
(
return_value
=
_mock_generate
())
self
.
create_abort_task
=
Mock
(
return_value
=
None
)
self
.
create_abort_task
=
Mock
()
@
pytest
.
fixture
def
mock_tokenizer_manager
():
"""Create a mock tokenizer manager for testing."""
return
MockTokenizerManager
()
@
pytest
.
fixture
def
serving_chat
(
mock_tokenizer_manager
):
"""Create a OpenAIServingChat instance for testing."""
return
OpenAIServingChat
(
mock_tokenizer_manager
)
class
ServingChatTestCase
(
unittest
.
TestCase
):
# ------------- common fixtures -------------
def
setUp
(
self
):
self
.
tm
=
_MockTokenizerManager
()
self
.
chat
=
OpenAIServingChat
(
self
.
tm
)
@
pytest
.
fixture
# frequently reused requests
def
mock_request
():
self
.
basic_req
=
ChatCompletionRequest
(
"""Create a mock FastAPI request."""
model
=
"x"
,
request
=
Mock
(
spec
=
Request
)
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hi?"
}],
request
.
headers
=
{}
temperature
=
0.7
,
return
request
max_tokens
=
100
,
stream
=
False
,
)
@
pytest
.
fixture
self
.
stream_req
=
ChatCompletionRequest
(
def
basic_chat_request
():
model
=
"x"
,
"""Create a basic chat completion request."""
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hi?"
}],
return
ChatCompletionRequest
(
temperature
=
0.7
,
model
=
"test-model"
,
max_tokens
=
100
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello, how are you?"
}],
stream
=
True
,
temperature
=
0.7
,
)
max_tokens
=
100
,
stream
=
False
,
)
@
pytest
.
fixture
def
streaming_chat_request
():
"""Create a streaming chat completion request."""
return
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello, how are you?"
}],
temperature
=
0.7
,
max_tokens
=
100
,
stream
=
True
,
)
class
TestOpenAIServingChatConversion
:
self
.
fastapi_request
=
Mock
(
spec
=
Request
)
"""Test request conversion methods."""
self
.
fastapi_request
.
headers
=
{}
def
test_convert_to_internal_request_single
(
# ------------- conversion tests -------------
self
,
serving_chat
,
basic_chat_request
,
mock_tokenizer_manager
def
test_convert_to_internal_request_single
(
self
):
):
"""Test converting single request to internal format."""
with
patch
(
with
patch
(
"sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
"sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
)
as
mock_conv
:
)
as
conv_mock
,
patch
.
object
(
self
.
chat
,
"_process_messages"
)
as
proc_mock
:
mock_conv_instance
=
Mock
()
conv_ins
=
Mock
()
mock_conv_instance
.
get_prompt
.
return_value
=
"Test prompt"
conv_ins
.
get_prompt
.
return_value
=
"Test prompt"
mock_conv_instance
.
image_data
=
None
conv_ins
.
image_data
=
conv_ins
.
audio_data
=
None
mock_conv_instance
.
audio_data
=
None
conv_ins
.
modalities
=
[]
mock_conv_instance
.
modalities
=
[]
conv_ins
.
stop_str
=
[
"</s>"
]
mock_conv_instance
.
stop_str
=
[
"</s>"
]
conv_mock
.
return_value
=
conv_ins
mock_conv
.
return_value
=
mock_conv_instance
proc_mock
.
return_value
=
(
# Mock the _process_messages method to return expected values
"Test prompt"
,
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
[
1
,
2
,
3
],
mock_process
.
return_value
=
(
None
,
"Test prompt"
,
None
,
[
1
,
2
,
3
],
[],
None
,
[
"</s>"
],
None
,
None
,
[],
)
[
"</s>"
],
None
,
# tool_call_constraint
)
adapted_request
,
processed_request
=
(
serving_chat
.
_convert_to_internal_request
(
[
basic_chat_request
],
[
"test-id"
]
)
)
assert
isinstance
(
adapted_request
,
GenerateReqInput
)
assert
adapted_request
.
stream
==
basic_chat_request
.
stream
assert
processed_request
==
basic_chat_request
class
TestToolCalls
:
adapted
,
processed
=
self
.
chat
.
_convert_to_internal_request
(
"""Test tool call functionality from adapter.py"""
[
self
.
basic_req
],
[
"rid"
]
)
self
.
assertIsInstance
(
adapted
,
GenerateReqInput
)
self
.
assertFalse
(
adapted
.
stream
)
self
.
assertEqual
(
processed
,
self
.
basic_req
)
def
test_
tool
_
call
_request_conversion
(
self
,
serving_chat
):
# -------------
tool
-
call
branch -------------
"""Test request with
tool
call
s"""
def
test_
tool
_
call
_request_conversion
(
self
):
req
uest
=
ChatCompletionRequest
(
req
=
ChatCompletionRequest
(
model
=
"
test-model
"
,
model
=
"
x
"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"W
hat's the w
eather?"
}],
messages
=
[{
"role"
:
"user"
,
"content"
:
"Weather?"
}],
tools
=
[
tools
=
[
{
{
"type"
:
"function"
,
"type"
:
"function"
,
"function"
:
{
"function"
:
{
"name"
:
"get_weather"
,
"name"
:
"get_weather"
,
"description"
:
"Get weather information"
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{}},
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"location"
:
{
"type"
:
"string"
}},
},
},
},
}
}
],
],
tool_choice
=
"auto"
,
tool_choice
=
"auto"
,
)
)
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
with
patch
.
object
(
mock_process
.
return_value
=
(
self
.
chat
,
"Test prompt"
,
"_process_messages"
,
[
1
,
2
,
3
],
return_value
=
(
"Prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
),
None
,
):
None
,
adapted
,
_
=
self
.
chat
.
_convert_to_internal_request
([
req
],
[
"rid"
])
[],
self
.
assertEqual
(
adapted
.
rid
,
"rid"
)
[
"</s>"
],
None
,
# tool_call_constraint
def
test_tool_choice_none
(
self
):
)
req
=
ChatCompletionRequest
(
model
=
"x"
,
adapted_request
,
_
=
serving_chat
.
_convert_to_internal_request
(
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hi"
}],
[
request
],
[
"test-id"
]
tools
=
[{
"type"
:
"function"
,
"function"
:
{
"name"
:
"noop"
}}],
)
assert
adapted_request
.
rid
==
"test-id"
# Tool call constraint should be processed
assert
request
.
tools
is
not
None
def
test_tool_choice_none
(
self
,
serving_chat
):
"""Test tool_choice=none disables tool calls"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}],
tools
=
[{
"type"
:
"function"
,
"function"
:
{
"name"
:
"test_func"
}}],
tool_choice
=
"none"
,
tool_choice
=
"none"
,
)
)
with
patch
.
object
(
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
self
.
chat
,
mock_process
.
return_value
=
(
"_process_messages"
,
"Test prompt"
,
return_value
=
(
"Prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
),
[
1
,
2
,
3
],
):
None
,
adapted
,
_
=
self
.
chat
.
_convert_to_internal_request
([
req
],
[
"rid"
])
None
,
self
.
assertEqual
(
adapted
.
rid
,
"rid"
)
[],
[
"</s>"
],
# ------------- multimodal branch -------------
None
,
# tool_call_constraint
def
test_multimodal_request_with_images
(
self
):
)
self
.
tm
.
model_config
.
is_multimodal
=
True
adapted_request
,
_
=
serving_chat
.
_convert_to_internal_request
(
req
=
ChatCompletionRequest
(
[
request
],
[
"test-id"
]
model
=
"x"
,
)
# Tools should not be processed when tool_choice is "none"
assert
adapted_request
.
rid
==
"test-id"
def
test_tool_call_response_processing
(
self
,
serving_chat
):
"""Test processing tool calls in response"""
mock_ret_item
=
{
"text"
:
'{"name": "get_weather", "parameters": {"location": "Paris"}}'
,
"meta_info"
:
{
"output_token_logprobs"
:
[],
"output_top_logprobs"
:
None
,
},
}
tools
=
[
{
"type"
:
"function"
,
"function"
:
{
"name"
:
"get_weather"
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"location"
:
{
"type"
:
"string"
}},
},
},
}
]
finish_reason
=
{
"type"
:
"stop"
,
"matched"
:
None
}
# Mock FunctionCallParser
with
patch
(
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
)
as
mock_parser_class
:
mock_parser
=
Mock
()
mock_parser
.
has_tool_call
.
return_value
=
True
# Create proper mock tool call object
mock_tool_call
=
Mock
()
mock_tool_call
.
name
=
"get_weather"
mock_tool_call
.
parameters
=
'{"location": "Paris"}'
mock_parser
.
parse_non_stream
.
return_value
=
(
""
,
[
mock_tool_call
])
mock_parser_class
.
return_value
=
mock_parser
tool_calls
,
text
,
updated_finish_reason
=
serving_chat
.
_process_tool_calls
(
mock_ret_item
[
"text"
],
tools
,
"hermes"
,
finish_reason
)
assert
tool_calls
is
not
None
assert
len
(
tool_calls
)
==
1
assert
updated_finish_reason
[
"type"
]
==
"tool_calls"
class
TestMultimodalContent
:
"""Test multimodal content handling from adapter.py"""
def
test_multimodal_request_with_images
(
self
,
serving_chat
):
"""Test request with image content"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[
messages
=
[
{
{
"role"
:
"user"
,
"role"
:
"user"
,
"content"
:
[
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"What's in th
is
image?"
},
{
"type"
:
"text"
,
"text"
:
"What's in th
e
image?"
},
{
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"data:image/jpeg;base64,
...
"
},
"image_url"
:
{
"url"
:
"data:image/jpeg;base64,"
},
},
},
],
],
}
}
],
],
)
)
# Set multimodal mode
with
patch
.
object
(
serving_chat
.
tokenizer_manager
.
model_config
.
is_multimodal
=
True
self
.
chat
,
"_apply_jinja_template"
,
with
patch
.
object
(
serving_chat
,
"_apply_jinja_template"
)
as
mock_apply
:
return_value
=
(
"prompt"
,
[
1
,
2
],
[
"img"
],
None
,
[],
[]),
mock_apply
.
return_value
=
(
),
patch
.
object
(
"prompt"
,
self
.
chat
,
[
1
,
2
,
3
],
"_apply_conversation_template"
,
[
"image_data"
],
return_value
=
(
"prompt"
,
[
"img"
],
None
,
[],
[]),
None
,
):
[],
out
=
self
.
chat
.
_process_messages
(
req
,
True
)
[],
_
,
_
,
image_data
,
*
_
=
out
)
self
.
assertEqual
(
image_data
,
[
"img"
])
with
patch
.
object
(
# ------------- template handling -------------
serving_chat
,
"_apply_conversation_template"
def
test_jinja_template_processing
(
self
):
)
as
mock_conv
:
req
=
ChatCompletionRequest
(
mock_conv
.
return_value
=
(
"prompt"
,
[
"image_data"
],
None
,
[],
[])
model
=
"x"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
(
prompt
,
prompt_ids
,
image_data
,
audio_data
,
modalities
,
stop
,
tool_call_constraint
,
)
=
serving_chat
.
_process_messages
(
request
,
True
)
assert
image_data
==
[
"image_data"
]
assert
prompt
==
"prompt"
def
test_multimodal_request_with_audio
(
self
,
serving_chat
):
"""Test request with audio content"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[
{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"Transcribe this audio"
},
{
"type"
:
"audio_url"
,
"audio_url"
:
{
"url"
:
"data:audio/wav;base64,UklGR..."
},
},
],
}
],
)
)
self
.
tm
.
chat_template_name
=
None
serving_chat
.
tokenizer_manager
.
model_config
.
is_multimodal
=
True
self
.
tm
.
tokenizer
.
chat_template
=
"<jinja>"
with
patch
.
object
(
serving_chat
,
"_apply_jinja_template"
)
as
mock_apply
:
with
patch
.
object
(
mock_apply
.
return_value
=
(
self
.
chat
,
"prompt"
,
"_apply_jinja_template"
,
[
1
,
2
,
3
],
return_value
=
(
"processed"
,
[
1
],
None
,
None
,
[],
[
"</s>"
]),
None
,
),
patch
(
"builtins.hasattr"
,
return_value
=
True
):
[
"audio_data"
],
prompt
,
prompt_ids
,
*
_
=
self
.
chat
.
_process_messages
(
req
,
False
)
[
"audio"
],
self
.
assertEqual
(
prompt
,
"processed"
)
[],
self
.
assertEqual
(
prompt_ids
,
[
1
])
)
# ------------- sampling-params -------------
with
patch
.
object
(
def
test_sampling_param_build
(
self
):
serving_chat
,
"_apply_conversation_template"
req
=
ChatCompletionRequest
(
)
as
mock_conv
:
model
=
"x"
,
mock_conv
.
return_value
=
(
"prompt"
,
None
,
[
"audio_data"
],
[
"audio"
],
[])
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hi"
}],
(
prompt
,
prompt_ids
,
image_data
,
audio_data
,
modalities
,
stop
,
tool_call_constraint
,
)
=
serving_chat
.
_process_messages
(
request
,
True
)
assert
audio_data
==
[
"audio_data"
]
assert
modalities
==
[
"audio"
]
class
TestTemplateHandling
:
"""Test chat template handling from adapter.py"""
def
test_jinja_template_processing
(
self
,
serving_chat
):
"""Test Jinja template processing"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
)
# Mock the template attribute directly
serving_chat
.
tokenizer_manager
.
chat_template_name
=
None
serving_chat
.
tokenizer_manager
.
tokenizer
.
chat_template
=
"<jinja_template>"
with
patch
.
object
(
serving_chat
,
"_apply_jinja_template"
)
as
mock_apply
:
mock_apply
.
return_value
=
(
"processed_prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
)
# Mock hasattr to simulate the None check
with
patch
(
"builtins.hasattr"
)
as
mock_hasattr
:
mock_hasattr
.
return_value
=
True
(
prompt
,
prompt_ids
,
image_data
,
audio_data
,
modalities
,
stop
,
tool_call_constraint
,
)
=
serving_chat
.
_process_messages
(
request
,
False
)
assert
prompt
==
"processed_prompt"
assert
prompt_ids
==
[
1
,
2
,
3
]
def
test_conversation_template_processing
(
self
,
serving_chat
):
"""Test conversation template processing"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}]
)
serving_chat
.
tokenizer_manager
.
chat_template_name
=
"llama-3"
with
patch
.
object
(
serving_chat
,
"_apply_conversation_template"
)
as
mock_apply
:
mock_apply
.
return_value
=
(
"conv_prompt"
,
None
,
None
,
[],
[
"</s>"
])
(
prompt
,
prompt_ids
,
image_data
,
audio_data
,
modalities
,
stop
,
tool_call_constraint
,
)
=
serving_chat
.
_process_messages
(
request
,
False
)
assert
prompt
==
"conv_prompt"
assert
stop
==
[
"</s>"
]
def
test_continue_final_message
(
self
,
serving_chat
):
"""Test continue_final_message functionality"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[
{
"role"
:
"user"
,
"content"
:
"Hello"
},
{
"role"
:
"assistant"
,
"content"
:
"Hi there"
},
],
continue_final_message
=
True
,
)
with
patch
.
object
(
serving_chat
,
"_apply_conversation_template"
)
as
mock_apply
:
mock_apply
.
return_value
=
(
"Hi there"
,
None
,
None
,
[],
[
"</s>"
])
(
prompt
,
prompt_ids
,
image_data
,
audio_data
,
modalities
,
stop
,
tool_call_constraint
,
)
=
serving_chat
.
_process_messages
(
request
,
False
)
# Should handle continue_final_message properly
assert
prompt
==
"Hi there"
class
TestReasoningContent
:
"""Test reasoning content separation from adapter.py"""
def
test_reasoning_content_request
(
self
,
serving_chat
):
"""Test request with reasoning content separation"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Solve this math problem"
}],
separate_reasoning
=
True
,
stream_reasoning
=
False
,
)
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
mock_process
.
return_value
=
(
"Test prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
,
# tool_call_constraint
)
adapted_request
,
_
=
serving_chat
.
_convert_to_internal_request
(
[
request
],
[
"test-id"
]
)
assert
adapted_request
.
rid
==
"test-id"
assert
request
.
separate_reasoning
==
True
def
test_reasoning_content_response
(
self
,
serving_chat
):
"""Test reasoning content in response"""
mock_ret_item
=
{
"text"
:
"<thinking>This is reasoning</thinking>Answer: 42"
,
"meta_info"
:
{
"output_token_logprobs"
:
[],
"output_top_logprobs"
:
None
,
},
}
# Mock ReasoningParser
with
patch
(
"sglang.srt.entrypoints.openai.serving_chat.ReasoningParser"
)
as
mock_parser_class
:
mock_parser
=
Mock
()
mock_parser
.
parse_non_stream
.
return_value
=
(
"This is reasoning"
,
"Answer: 42"
,
)
mock_parser_class
.
return_value
=
mock_parser
choice_logprobs
=
None
reasoning_text
=
None
text
=
mock_ret_item
[
"text"
]
# Simulate reasoning processing
enable_thinking
=
True
if
enable_thinking
:
parser
=
mock_parser_class
(
model_type
=
"test"
,
stream_reasoning
=
False
)
reasoning_text
,
text
=
parser
.
parse_non_stream
(
text
)
assert
reasoning_text
==
"This is reasoning"
assert
text
==
"Answer: 42"
class
TestSamplingParams
:
"""Test sampling parameter handling from adapter.py"""
def
test_all_sampling_parameters
(
self
,
serving_chat
):
"""Test all sampling parameters are properly handled"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Hello"
}],
temperature
=
0.8
,
temperature
=
0.8
,
max_tokens
=
150
,
max_tokens
=
150
,
max_completion_tokens
=
200
,
min_tokens
=
5
,
min_tokens
=
5
,
top_p
=
0.9
,
top_p
=
0.9
,
top_k
=
50
,
stop
=
[
"</s>"
],
min_p
=
0.1
,
presence_penalty
=
0.1
,
frequency_penalty
=
0.2
,
repetition_penalty
=
1.1
,
stop
=
[
"<|endoftext|>"
],
stop_token_ids
=
[
13
,
14
],
regex
=
r
"\d+"
,
ebnf
=
"<expr> ::= <number>"
,
n
=
2
,
no_stop_trim
=
True
,
ignore_eos
=
True
,
skip_special_tokens
=
False
,
logit_bias
=
{
"1"
:
0.5
,
"2"
:
-
0.3
},
)
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
mock_process
.
return_value
=
(
"Test prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
,
# tool_call_constraint
)
sampling_params
=
serving_chat
.
_build_sampling_params
(
request
,
[
"</s>"
],
None
)
# Verify all parameters
assert
sampling_params
[
"temperature"
]
==
0.8
assert
sampling_params
[
"max_new_tokens"
]
==
150
assert
sampling_params
[
"min_new_tokens"
]
==
5
assert
sampling_params
[
"top_p"
]
==
0.9
assert
sampling_params
[
"top_k"
]
==
50
assert
sampling_params
[
"min_p"
]
==
0.1
assert
sampling_params
[
"presence_penalty"
]
==
0.1
assert
sampling_params
[
"frequency_penalty"
]
==
0.2
assert
sampling_params
[
"repetition_penalty"
]
==
1.1
assert
sampling_params
[
"stop"
]
==
[
"</s>"
]
assert
sampling_params
[
"logit_bias"
]
==
{
"1"
:
0.5
,
"2"
:
-
0.3
}
def
test_response_format_json_schema
(
self
,
serving_chat
):
"""Test response format with JSON schema"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Generate JSON"
}],
response_format
=
{
"type"
:
"json_schema"
,
"json_schema"
:
{
"name"
:
"response"
,
"schema"
:
{
"type"
:
"object"
,
"properties"
:
{
"answer"
:
{
"type"
:
"string"
}},
},
},
},
)
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
mock_process
.
return_value
=
(
"Test prompt"
,
[
1
,
2
,
3
],
None
,
None
,
[],
[
"</s>"
],
None
,
# tool_call_constraint
)
sampling_params
=
serving_chat
.
_build_sampling_params
(
request
,
[
"</s>"
],
None
)
assert
"json_schema"
in
sampling_params
assert
'"type": "object"'
in
sampling_params
[
"json_schema"
]
def
test_response_format_json_object
(
self
,
serving_chat
):
"""Test response format with JSON object"""
request
=
ChatCompletionRequest
(
model
=
"test-model"
,
messages
=
[{
"role"
:
"user"
,
"content"
:
"Generate JSON"
}],
response_format
=
{
"type"
:
"json_object"
},
)
)
with
patch
.
object
(
with
patch
.
object
(
serving_chat
,
"_process_messages"
)
as
mock_process
:
self
.
chat
,
mock_process
.
return_value
=
(
"_process_messages"
,
"Test prompt"
,
return_value
=
(
"Prompt"
,
[
1
],
None
,
None
,
[],
[
"</s>"
],
None
),
[
1
,
2
,
3
],
):
None
,
params
=
self
.
chat
.
_build_sampling_params
(
req
,
[
"</s>"
],
None
)
None
,
self
.
assertEqual
(
params
[
"temperature"
],
0.8
)
[],
self
.
assertEqual
(
params
[
"max_new_tokens"
],
150
)
[
"</s>"
],
self
.
assertEqual
(
params
[
"min_new_tokens"
],
5
)
None
,
# tool_call_constraint
self
.
assertEqual
(
params
[
"stop"
],
[
"</s>"
])
)
sampling_params
=
serving_chat
.
_build_sampling_params
(
if
__name__
==
"__main__"
:
request
,
[
"</s>"
],
None
unittest
.
main
(
verbosity
=
2
)
)
assert
sampling_params
[
"json_schema"
]
==
'{"type": "object"}'
test/srt/openai/test_serving_completions.py
View file @
ffd1a26e
"""
"""
Tests for the refactored completions serving handler
Unit-tests for the refactored completions-serving handler (no pytest).
Run with:
python -m unittest tests.test_serving_completions_unit -v
"""
"""
import
unittest
from
unittest.mock
import
AsyncMock
,
Mock
,
patch
from
unittest.mock
import
AsyncMock
,
Mock
,
patch
import
pytest
from
sglang.srt.entrypoints.openai.protocol
import
CompletionRequest
from
sglang.srt.entrypoints.openai.protocol
import
(
CompletionRequest
,
CompletionResponse
,
CompletionResponseChoice
,
CompletionStreamResponse
,
ErrorResponse
,
)
from
sglang.srt.entrypoints.openai.serving_completions
import
OpenAIServingCompletion
from
sglang.srt.entrypoints.openai.serving_completions
import
OpenAIServingCompletion
from
sglang.srt.managers.io_struct
import
GenerateReqInput
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
@
pytest
.
fixture
class
ServingCompletionTestCase
(
unittest
.
TestCase
):
def
mock_tokenizer_manager
():
"""Bundle all prompt/echo tests in one TestCase."""
"""Create a mock tokenizer manager"""
manager
=
Mock
(
spec
=
TokenizerManager
)
# Mock tokenizer
manager
.
tokenizer
=
Mock
()
manager
.
tokenizer
.
encode
=
Mock
(
return_value
=
[
1
,
2
,
3
,
4
])
manager
.
tokenizer
.
decode
=
Mock
(
return_value
=
"decoded text"
)
manager
.
tokenizer
.
bos_token_id
=
1
# Mock model config
manager
.
model_config
=
Mock
()
manager
.
model_config
.
is_multimodal
=
False
# Mock server args
manager
.
server_args
=
Mock
()
manager
.
server_args
.
enable_cache_report
=
False
# Mock generation
# ---------- shared test fixtures ----------
manager
.
generate_request
=
AsyncMock
()
def
setUp
(
self
):
manager
.
create_abort_task
=
Mock
(
return_value
=
None
)
# build the mock TokenizerManager once for every test
tm
=
Mock
(
spec
=
TokenizerManager
)
return
manager
tm
.
tokenizer
=
Mock
()
tm
.
tokenizer
.
encode
.
return_value
=
[
1
,
2
,
3
,
4
]
tm
.
tokenizer
.
decode
.
return_value
=
"decoded text"
tm
.
tokenizer
.
bos_token_id
=
1
tm
.
model_config
=
Mock
(
is_multimodal
=
False
)
tm
.
server_args
=
Mock
(
enable_cache_report
=
False
)
@
pytest
.
fixture
tm
.
generate_request
=
AsyncMock
()
def
serving_completion
(
mock_tokenizer_manager
):
tm
.
create_abort_task
=
Mock
()
"""Create a OpenAIServingCompletion instance"""
return
OpenAIServingCompletion
(
mock_tokenizer_manager
)
self
.
sc
=
OpenAIServingCompletion
(
tm
)
class
TestPromptHandling
:
# ---------- prompt-handling ----------
"""Test different prompt types and formats from adapter.py"""
def
test_single_string_prompt
(
self
):
req
=
CompletionRequest
(
model
=
"x"
,
prompt
=
"Hello world"
,
max_tokens
=
100
)
def
test_single_string_prompt
(
self
,
serving_completion
):
internal
,
_
=
self
.
sc
.
_convert_to_internal_request
([
req
],
[
"id"
])
"""Test handling single string prompt"""
self
.
assertEqual
(
internal
.
text
,
"Hello world"
)
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello world"
,
max_tokens
=
100
)
adapted_request
,
_
=
serving_completion
.
_convert_to_internal_request
(
[
request
],
[
"test-id"
]
)
assert
adapted_request
.
text
==
"Hello world"
def
test_single_token_ids_prompt
(
self
,
serving_completion
):
"""Test handling single token IDs prompt"""
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
[
1
,
2
,
3
,
4
],
max_tokens
=
100
)
adapted_request
,
_
=
serving_completion
.
_convert_to_internal_request
(
[
request
],
[
"test-id"
]
)
assert
adapted_request
.
input_ids
==
[
1
,
2
,
3
,
4
]
def
test_single_token_ids_prompt
(
self
):
req
=
CompletionRequest
(
model
=
"x"
,
prompt
=
[
1
,
2
,
3
,
4
],
max_tokens
=
100
)
internal
,
_
=
self
.
sc
.
_convert_to_internal_request
([
req
],
[
"id"
])
self
.
assertEqual
(
internal
.
input_ids
,
[
1
,
2
,
3
,
4
])
def
test_completion_template_handling
(
self
,
serving_completion
):
def
test_completion_template_handling
(
self
):
"""Test completion template processing"""
req
=
CompletionRequest
(
request
=
CompletionRequest
(
model
=
"x"
,
prompt
=
"def f():"
,
suffix
=
"return 1"
,
max_tokens
=
100
model
=
"test-model"
,
prompt
=
"def hello():"
,
suffix
=
"return 'world'"
,
max_tokens
=
100
,
)
)
with
patch
(
with
patch
(
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined"
,
"sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined"
,
return_value
=
True
,
return_value
=
True
,
),
patch
(
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request"
,
return_value
=
"processed_prompt"
,
):
):
with
patch
(
internal
,
_
=
self
.
sc
.
_convert_to_internal_request
([
req
],
[
"id"
])
"sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request"
,
self
.
assertEqual
(
internal
.
text
,
"processed_prompt"
)
return_value
=
"processed_prompt"
,
):
adapted_request
,
_
=
serving_completion
.
_convert_to_internal_request
(
[
request
],
[
"test-id"
]
)
assert
adapted_request
.
text
==
"processed_prompt"
# ---------- echo-handling ----------
def
test_echo_with_string_prompt_streaming
(
self
):
req
=
CompletionRequest
(
model
=
"x"
,
prompt
=
"Hello"
,
max_tokens
=
1
,
echo
=
True
)
self
.
assertEqual
(
self
.
sc
.
_get_echo_text
(
req
,
0
),
"Hello"
)
class
TestEchoHandling
:
def
test_echo_with_list_of_strings_streaming
(
self
):
"""Test echo functionality from adapter.py"""
req
=
CompletionRequest
(
model
=
"x"
,
prompt
=
[
"A"
,
"B"
],
max_tokens
=
1
,
echo
=
True
,
n
=
1
def
test_echo_with_string_prompt_streaming
(
self
,
serving_completion
):
"""Test echo handling with string prompt in streaming"""
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello"
,
max_tokens
=
100
,
echo
=
True
)
)
self
.
assertEqual
(
self
.
sc
.
_get_echo_text
(
req
,
0
),
"A"
)
self
.
assertEqual
(
self
.
sc
.
_get_echo_text
(
req
,
1
),
"B"
)
# Test _get_echo_text method
def
test_echo_with_token_ids_streaming
(
self
):
echo_text
=
serving_completion
.
_get_echo_text
(
request
,
0
)
req
=
CompletionRequest
(
model
=
"x"
,
prompt
=
[
1
,
2
,
3
],
max_tokens
=
1
,
echo
=
True
)
assert
echo_text
==
"Hello"
self
.
sc
.
tokenizer_manager
.
tokenizer
.
decode
.
return_value
=
"decoded_prompt"
self
.
assertEqual
(
self
.
sc
.
_get_echo_text
(
req
,
0
),
"decoded_prompt"
)
def
test_echo_with_list_of_strings_streaming
(
self
,
serving_completion
):
"""Test echo handling with list of strings in streaming"""
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
[
"Hello"
,
"World"
],
max_tokens
=
100
,
echo
=
True
,
n
=
1
,
)
echo_text
=
serving_completion
.
_get_echo_text
(
request
,
0
)
def
test_echo_with_multiple_token_ids_streaming
(
self
):
assert
echo_text
==
"Hello"
req
=
CompletionRequest
(
model
=
"x"
,
prompt
=
[[
1
,
2
],
[
3
,
4
]],
max_tokens
=
1
,
echo
=
True
,
n
=
1
echo_text
=
serving_completion
.
_get_echo_text
(
request
,
1
)
assert
echo_text
==
"World"
def
test_echo_with_token_ids_streaming
(
self
,
serving_completion
):
"""Test echo handling with token IDs in streaming"""
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
[
1
,
2
,
3
],
max_tokens
=
100
,
echo
=
True
)
)
self
.
sc
.
tokenizer_manager
.
tokenizer
.
decode
.
return_value
=
"decoded"
self
.
assertEqual
(
self
.
sc
.
_get_echo_text
(
req
,
0
),
"decoded"
)
serving_completion
.
tokenizer_manager
.
tokenizer
.
decode
.
return_value
=
(
def
test_prepare_echo_prompts_non_streaming
(
self
):
"decoded_prompt"
# single string
)
req
=
CompletionRequest
(
model
=
"x"
,
prompt
=
"Hi"
,
echo
=
True
)
echo_text
=
serving_completion
.
_get_echo_text
(
request
,
0
)
self
.
assertEqual
(
self
.
sc
.
_prepare_echo_prompts
(
req
),
[
"Hi"
])
assert
echo_text
==
"decoded_prompt"
def
test_echo_with_multiple_token_ids_streaming
(
self
,
serving_completion
):
# list of strings
"""Test echo handling with multiple token ID prompts in streaming"""
req
=
CompletionRequest
(
model
=
"x"
,
prompt
=
[
"Hi"
,
"Yo"
],
echo
=
True
)
request
=
CompletionRequest
(
self
.
assertEqual
(
self
.
sc
.
_prepare_echo_prompts
(
req
),
[
"Hi"
,
"Yo"
])
model
=
"test-model"
,
prompt
=
[[
1
,
2
],
[
3
,
4
]],
max_tokens
=
100
,
echo
=
True
,
n
=
1
)
serving_completion
.
tokenizer_manager
.
tokenizer
.
decode
.
return_value
=
"decoded"
echo_text
=
serving_completion
.
_get_echo_text
(
request
,
0
)
assert
echo_text
==
"decoded"
def
test_prepare_echo_prompts_non_streaming
(
self
,
serving_completion
):
"""Test prepare echo prompts for non-streaming response"""
# Test with single string
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
"Hello"
,
echo
=
True
)
echo_prompts
=
serving_completion
.
_prepare_echo_prompts
(
request
)
assert
echo_prompts
==
[
"Hello"
]
# Test with list of strings
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
[
"Hello"
,
"World"
],
echo
=
True
)
echo_prompts
=
serving_completion
.
_prepare_echo_prompts
(
request
)
# token IDs
assert
echo_prompts
==
[
"Hello"
,
"World"
]
req
=
CompletionRequest
(
model
=
"x"
,
prompt
=
[
1
,
2
,
3
],
echo
=
True
)
self
.
sc
.
tokenizer_manager
.
tokenizer
.
decode
.
return_value
=
"decoded"
self
.
assertEqual
(
self
.
sc
.
_prepare_echo_prompts
(
req
),
[
"decoded"
])
# Test with token IDs
request
=
CompletionRequest
(
model
=
"test-model"
,
prompt
=
[
1
,
2
,
3
],
echo
=
True
)
serving_completion
.
tokenizer_manager
.
tokenizer
.
decode
.
return_value
=
"decoded"
if
__name__
==
"__main__"
:
echo_prompts
=
serving_completion
.
_prepare_echo_prompts
(
request
)
unittest
.
main
(
verbosity
=
2
)
assert
echo_prompts
==
[
"decoded"
]
test/srt/openai/test_serving_embedding.py
View file @
ffd1a26e
...
@@ -8,11 +8,11 @@ with the original adapter.py functionality and follows OpenAI API specifications
...
@@ -8,11 +8,11 @@ with the original adapter.py functionality and follows OpenAI API specifications
import
asyncio
import
asyncio
import
json
import
json
import
time
import
time
import
unittest
import
uuid
import
uuid
from
typing
import
Any
,
Dict
,
List
from
typing
import
Any
,
Dict
,
List
from
unittest.mock
import
AsyncMock
,
Mock
,
patch
from
unittest.mock
import
AsyncMock
,
Mock
,
patch
import
pytest
from
fastapi
import
Request
from
fastapi
import
Request
from
fastapi.responses
import
ORJSONResponse
from
fastapi.responses
import
ORJSONResponse
from
pydantic_core
import
ValidationError
from
pydantic_core
import
ValidationError
...
@@ -30,7 +30,7 @@ from sglang.srt.managers.io_struct import EmbeddingReqInput
...
@@ -30,7 +30,7 @@ from sglang.srt.managers.io_struct import EmbeddingReqInput
# Mock TokenizerManager for embedding tests
# Mock TokenizerManager for embedding tests
class
MockTokenizerManager
:
class
_
MockTokenizerManager
:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
model_config
=
Mock
()
self
.
model_config
=
Mock
()
self
.
model_config
.
is_multimodal
=
False
self
.
model_config
.
is_multimodal
=
False
...
@@ -58,141 +58,98 @@ class MockTokenizerManager:
...
@@ -58,141 +58,98 @@ class MockTokenizerManager:
self
.
generate_request
=
Mock
(
return_value
=
mock_generate_embedding
())
self
.
generate_request
=
Mock
(
return_value
=
mock_generate_embedding
())
@
pytest
.
fixture
class
ServingEmbeddingTestCase
(
unittest
.
TestCase
):
def
mock_tokenizer_manager
():
def
setUp
(
self
):
"""Create a mock tokenizer manager for testing."""
"""Set up test fixtures."""
return
MockTokenizerManager
()
self
.
tokenizer_manager
=
_MockTokenizerManager
()
self
.
serving_embedding
=
OpenAIServingEmbedding
(
self
.
tokenizer_manager
)
self
.
request
=
Mock
(
spec
=
Request
)
self
.
request
.
headers
=
{}
@
pytest
.
fixture
self
.
basic_req
=
EmbeddingRequest
(
def
serving_embedding
(
mock_tokenizer_manager
):
model
=
"test-model"
,
"""Create an OpenAIServingEmbedding instance for testing."""
input
=
"Hello, how are you?"
,
return
OpenAIServingEmbedding
(
mock_tokenizer_manager
)
encoding_format
=
"float"
,
)
self
.
list_req
=
EmbeddingRequest
(
@
pytest
.
fixture
model
=
"test-model"
,
def
mock_request
():
input
=
[
"Hello, how are you?"
,
"I am fine, thank you!"
],
"""Create a mock FastAPI request."""
encoding_format
=
"float"
,
request
=
Mock
(
spec
=
Request
)
)
request
.
headers
=
{}
self
.
multimodal_req
=
EmbeddingRequest
(
return
request
model
=
"test-model"
,
input
=
[
MultimodalEmbeddingInput
(
text
=
"Hello"
,
image
=
"base64_image_data"
),
@
pytest
.
fixture
MultimodalEmbeddingInput
(
text
=
"World"
,
image
=
None
),
def
basic_embedding_request
():
],
"""Create a basic embedding request."""
encoding_format
=
"float"
,
return
EmbeddingRequest
(
)
model
=
"test-model"
,
self
.
token_ids_req
=
EmbeddingRequest
(
input
=
"Hello, how are you?"
,
model
=
"test-model"
,
encoding_format
=
"float"
,
input
=
[
1
,
2
,
3
,
4
,
5
],
)
encoding_format
=
"float"
,
)
@
pytest
.
fixture
def
list_embedding_request
():
"""Create an embedding request with list input."""
return
EmbeddingRequest
(
model
=
"test-model"
,
input
=
[
"Hello, how are you?"
,
"I am fine, thank you!"
],
encoding_format
=
"float"
,
)
@
pytest
.
fixture
def
multimodal_embedding_request
():
"""Create a multimodal embedding request."""
return
EmbeddingRequest
(
model
=
"test-model"
,
input
=
[
MultimodalEmbeddingInput
(
text
=
"Hello"
,
image
=
"base64_image_data"
),
MultimodalEmbeddingInput
(
text
=
"World"
,
image
=
None
),
],
encoding_format
=
"float"
,
)
@
pytest
.
fixture
def
token_ids_embedding_request
():
"""Create an embedding request with token IDs."""
return
EmbeddingRequest
(
model
=
"test-model"
,
input
=
[
1
,
2
,
3
,
4
,
5
],
encoding_format
=
"float"
,
)
class
TestOpenAIServingEmbeddingConversion
:
"""Test request conversion methods."""
def
test_convert_single_string_request
(
def
test_convert_single_string_request
(
self
):
self
,
serving_embedding
,
basic_embedding_request
):
"""Test converting single string request to internal format."""
"""Test converting single string request to internal format."""
adapted_request
,
processed_request
=
(
adapted_request
,
processed_request
=
(
serving_embedding
.
_convert_to_internal_request
(
self
.
serving_embedding
.
_convert_to_internal_request
(
[
basic_
embedding_request
],
[
"test-id"
]
[
self
.
basic_
req
],
[
"test-id"
]
)
)
)
)
assert
isi
nstance
(
adapted_request
,
EmbeddingReqInput
)
self
.
assert
IsI
nstance
(
adapted_request
,
EmbeddingReqInput
)
assert
adapted_request
.
text
==
"Hello, how are you?"
self
.
assert
Equal
(
adapted_request
.
text
,
"Hello, how are you?"
)
assert
adapted_request
.
rid
==
"test-id"
self
.
assert
Equal
(
adapted_request
.
rid
,
"test-id"
)
assert
processed_request
==
basic_embedding_request
self
.
assert
Equal
(
processed_request
,
self
.
basic_req
)
def
test_convert_list_string_request
(
def
test_convert_list_string_request
(
self
):
self
,
serving_embedding
,
list_embedding_request
):
"""Test converting list of strings request to internal format."""
"""Test converting list of strings request to internal format."""
adapted_request
,
processed_request
=
(
adapted_request
,
processed_request
=
(
serving_embedding
.
_convert_to_internal_request
(
self
.
serving_embedding
.
_convert_to_internal_request
(
[
list_embedding_request
],
[
"test-id"
]
[
self
.
list_req
],
[
"test-id"
]
)
)
)
)
assert
isinstance
(
adapted_request
,
EmbeddingReqInput
)
self
.
assertIsInstance
(
adapted_request
,
EmbeddingReqInput
)
assert
adapted_request
.
text
==
[
"Hello, how are you?"
,
"I am fine, thank you!"
]
self
.
assertEqual
(
assert
adapted_request
.
rid
==
"test-id"
adapted_request
.
text
,
[
"Hello, how are you?"
,
"I am fine, thank you!"
]
assert
processed_request
==
list_embedding_request
)
self
.
assertEqual
(
adapted_request
.
rid
,
"test-id"
)
self
.
assertEqual
(
processed_request
,
self
.
list_req
)
def
test_convert_token_ids_request
(
def
test_convert_token_ids_request
(
self
):
self
,
serving_embedding
,
token_ids_embedding_request
):
"""Test converting token IDs request to internal format."""
"""Test converting token IDs request to internal format."""
adapted_request
,
processed_request
=
(
adapted_request
,
processed_request
=
(
serving_embedding
.
_convert_to_internal_request
(
self
.
serving_embedding
.
_convert_to_internal_request
(
[
token_ids_
embedding_request
],
[
"test-id"
]
[
self
.
token_ids_
req
],
[
"test-id"
]
)
)
)
)
assert
isi
nstance
(
adapted_request
,
EmbeddingReqInput
)
self
.
assert
IsI
nstance
(
adapted_request
,
EmbeddingReqInput
)
assert
adapted_request
.
input_ids
==
[
1
,
2
,
3
,
4
,
5
]
self
.
assert
Equal
(
adapted_request
.
input_ids
,
[
1
,
2
,
3
,
4
,
5
]
)
assert
adapted_request
.
rid
==
"test-id"
self
.
assert
Equal
(
adapted_request
.
rid
,
"test-id"
)
assert
processed_request
==
token_ids_
embedding_request
self
.
assert
Equal
(
processed_request
,
self
.
token_ids_
req
)
def
test_convert_multimodal_request
(
def
test_convert_multimodal_request
(
self
):
self
,
serving_embedding
,
multimodal_embedding_request
):
"""Test converting multimodal request to internal format."""
"""Test converting multimodal request to internal format."""
adapted_request
,
processed_request
=
(
adapted_request
,
processed_request
=
(
serving_embedding
.
_convert_to_internal_request
(
self
.
serving_embedding
.
_convert_to_internal_request
(
[
multimodal_
embedding_request
],
[
"test-id"
]
[
self
.
multimodal_
req
],
[
"test-id"
]
)
)
)
)
assert
isi
nstance
(
adapted_request
,
EmbeddingReqInput
)
self
.
assert
IsI
nstance
(
adapted_request
,
EmbeddingReqInput
)
# Should extract text and images separately
# Should extract text and images separately
assert
len
(
adapted_request
.
text
)
==
2
self
.
assertEqual
(
len
(
adapted_request
.
text
),
2
)
assert
"Hello"
in
adapted_request
.
text
self
.
assertIn
(
"Hello"
,
adapted_request
.
text
)
assert
"World"
in
adapted_request
.
text
self
.
assertIn
(
"World"
,
adapted_request
.
text
)
assert
adapted_request
.
image_data
[
0
]
==
"base64_image_data"
self
.
assertEqual
(
adapted_request
.
image_data
[
0
],
"base64_image_data"
)
assert
adapted_request
.
image_data
[
1
]
is
None
self
.
assertIsNone
(
adapted_request
.
image_data
[
1
])
assert
adapted_request
.
rid
==
"test-id"
self
.
assertEqual
(
adapted_request
.
rid
,
"test-id"
)
def
test_build_single_embedding_response
(
self
):
class
TestEmbeddingResponseBuilding
:
"""Test response building methods."""
def
test_build_single_embedding_response
(
self
,
serving_embedding
):
"""Test building response for single embedding."""
"""Test building response for single embedding."""
ret_data
=
[
ret_data
=
[
{
{
...
@@ -201,19 +158,21 @@ class TestEmbeddingResponseBuilding:
...
@@ -201,19 +158,21 @@ class TestEmbeddingResponseBuilding:
}
}
]
]
response
=
serving_embedding
.
_build_embedding_response
(
ret_data
,
"test-model"
)
response
=
self
.
serving_embedding
.
_build_embedding_response
(
ret_data
,
"test-model"
assert
isinstance
(
response
,
EmbeddingResponse
)
)
assert
response
.
model
==
"test-model"
assert
len
(
response
.
data
)
==
1
assert
response
.
data
[
0
].
embedding
==
[
0.1
,
0.2
,
0.3
,
0.4
,
0.5
]
assert
response
.
data
[
0
].
index
==
0
assert
response
.
data
[
0
].
object
==
"embedding"
assert
response
.
usage
.
prompt_tokens
==
5
assert
response
.
usage
.
total_tokens
==
5
assert
response
.
usage
.
completion_tokens
==
0
def
test_build_multiple_embedding_response
(
self
,
serving_embedding
):
self
.
assertIsInstance
(
response
,
EmbeddingResponse
)
self
.
assertEqual
(
response
.
model
,
"test-model"
)
self
.
assertEqual
(
len
(
response
.
data
),
1
)
self
.
assertEqual
(
response
.
data
[
0
].
embedding
,
[
0.1
,
0.2
,
0.3
,
0.4
,
0.5
])
self
.
assertEqual
(
response
.
data
[
0
].
index
,
0
)
self
.
assertEqual
(
response
.
data
[
0
].
object
,
"embedding"
)
self
.
assertEqual
(
response
.
usage
.
prompt_tokens
,
5
)
self
.
assertEqual
(
response
.
usage
.
total_tokens
,
5
)
self
.
assertEqual
(
response
.
usage
.
completion_tokens
,
0
)
def
test_build_multiple_embedding_response
(
self
):
"""Test building response for multiple embeddings."""
"""Test building response for multiple embeddings."""
ret_data
=
[
ret_data
=
[
{
{
...
@@ -226,25 +185,20 @@ class TestEmbeddingResponseBuilding:
...
@@ -226,25 +185,20 @@ class TestEmbeddingResponseBuilding:
},
},
]
]
response
=
serving_embedding
.
_build_embedding_response
(
ret_data
,
"test-model"
)
response
=
self
.
serving_embedding
.
_build_embedding_response
(
ret_data
,
"test-model"
assert
isinstance
(
response
,
EmbeddingResponse
)
)
assert
len
(
response
.
data
)
==
2
assert
response
.
data
[
0
].
embedding
==
[
0.1
,
0.2
,
0.3
]
assert
response
.
data
[
0
].
index
==
0
assert
response
.
data
[
1
].
embedding
==
[
0.4
,
0.5
,
0.6
]
assert
response
.
data
[
1
].
index
==
1
assert
response
.
usage
.
prompt_tokens
==
7
# 3 + 4
assert
response
.
usage
.
total_tokens
==
7
@
pytest
.
mark
.
asyncio
self
.
assertIsInstance
(
response
,
EmbeddingResponse
)
class
TestOpenAIServingEmbeddingAsyncMethods
:
self
.
assertEqual
(
len
(
response
.
data
),
2
)
"""Test async methods of OpenAIServingEmbedding."""
self
.
assertEqual
(
response
.
data
[
0
].
embedding
,
[
0.1
,
0.2
,
0.3
])
self
.
assertEqual
(
response
.
data
[
0
].
index
,
0
)
self
.
assertEqual
(
response
.
data
[
1
].
embedding
,
[
0.4
,
0.5
,
0.6
])
self
.
assertEqual
(
response
.
data
[
1
].
index
,
1
)
self
.
assertEqual
(
response
.
usage
.
prompt_tokens
,
7
)
# 3 + 4
self
.
assertEqual
(
response
.
usage
.
total_tokens
,
7
)
async
def
test_handle_request_success
(
async
def
test_handle_request_success
(
self
):
self
,
serving_embedding
,
basic_embedding_request
,
mock_request
):
"""Test successful embedding request handling."""
"""Test successful embedding request handling."""
# Mock the generate_request to return expected data
# Mock the generate_request to return expected data
...
@@ -254,32 +208,30 @@ class TestOpenAIServingEmbeddingAsyncMethods:
...
@@ -254,32 +208,30 @@ class TestOpenAIServingEmbeddingAsyncMethods:
"meta_info"
:
{
"prompt_tokens"
:
5
},
"meta_info"
:
{
"prompt_tokens"
:
5
},
}
}
serving_embedding
.
tokenizer_manager
.
generate_request
=
Mock
(
self
.
serving_embedding
.
tokenizer_manager
.
generate_request
=
Mock
(
return_value
=
mock_generate
()
return_value
=
mock_generate
()
)
)
response
=
await
serving_embedding
.
handle_request
(
response
=
await
self
.
serving_embedding
.
handle_request
(
basic_embedding_request
,
mock_
request
self
.
basic_req
,
self
.
request
)
)
assert
isi
nstance
(
response
,
EmbeddingResponse
)
self
.
assert
IsI
nstance
(
response
,
EmbeddingResponse
)
assert
len
(
response
.
data
)
==
1
self
.
assert
Equal
(
len
(
response
.
data
)
,
1
)
assert
response
.
data
[
0
].
embedding
==
[
0.1
,
0.2
,
0.3
,
0.4
,
0.5
]
self
.
assert
Equal
(
response
.
data
[
0
].
embedding
,
[
0.1
,
0.2
,
0.3
,
0.4
,
0.5
]
)
async
def
test_handle_request_validation_error
(
async
def
test_handle_request_validation_error
(
self
):
self
,
serving_embedding
,
mock_request
):
"""Test handling request with validation error."""
"""Test handling request with validation error."""
invalid_request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
""
)
invalid_request
=
EmbeddingRequest
(
model
=
"test-model"
,
input
=
""
)
response
=
await
serving_embedding
.
handle_request
(
invalid_request
,
mock_request
)
response
=
await
self
.
serving_embedding
.
handle_request
(
invalid_request
,
self
.
request
)
assert
isi
nstance
(
response
,
ORJSONResponse
)
self
.
assert
IsI
nstance
(
response
,
ORJSONResponse
)
assert
response
.
status_code
==
400
self
.
assert
Equal
(
response
.
status_code
,
400
)
async
def
test_handle_request_generation_error
(
async
def
test_handle_request_generation_error
(
self
):
self
,
serving_embedding
,
basic_embedding_request
,
mock_request
):
"""Test handling request with generation error."""
"""Test handling request with generation error."""
# Mock generate_request to raise an error
# Mock generate_request to raise an error
...
@@ -287,30 +239,32 @@ class TestOpenAIServingEmbeddingAsyncMethods:
...
@@ -287,30 +239,32 @@ class TestOpenAIServingEmbeddingAsyncMethods:
raise
ValueError
(
"Generation failed"
)
raise
ValueError
(
"Generation failed"
)
yield
# This won't be reached but needed for async generator
yield
# This won't be reached but needed for async generator
serving_embedding
.
tokenizer_manager
.
generate_request
=
Mock
(
self
.
serving_embedding
.
tokenizer_manager
.
generate_request
=
Mock
(
return_value
=
mock_generate_error
()
return_value
=
mock_generate_error
()
)
)
response
=
await
serving_embedding
.
handle_request
(
response
=
await
self
.
serving_embedding
.
handle_request
(
basic_embedding_request
,
mock_
request
self
.
basic_req
,
self
.
request
)
)
assert
isi
nstance
(
response
,
ORJSONResponse
)
self
.
assert
IsI
nstance
(
response
,
ORJSONResponse
)
assert
response
.
status_code
==
400
self
.
assert
Equal
(
response
.
status_code
,
400
)
async
def
test_handle_request_internal_error
(
async
def
test_handle_request_internal_error
(
self
):
self
,
serving_embedding
,
basic_embedding_request
,
mock_request
):
"""Test handling request with internal server error."""
"""Test handling request with internal server error."""
# Mock _convert_to_internal_request to raise an exception
# Mock _convert_to_internal_request to raise an exception
with
patch
.
object
(
with
patch
.
object
(
serving_embedding
,
self
.
serving_embedding
,
"_convert_to_internal_request"
,
"_convert_to_internal_request"
,
side_effect
=
Exception
(
"Internal error"
),
side_effect
=
Exception
(
"Internal error"
),
):
):
response
=
await
serving_embedding
.
handle_request
(
response
=
await
self
.
serving_embedding
.
handle_request
(
basic_embedding_request
,
mock_
request
self
.
basic_req
,
self
.
request
)
)
assert
isinstance
(
response
,
ORJSONResponse
)
self
.
assertIsInstance
(
response
,
ORJSONResponse
)
assert
response
.
status_code
==
500
self
.
assertEqual
(
response
.
status_code
,
500
)
if
__name__
==
"__main__"
:
unittest
.
main
(
verbosity
=
2
)
test/srt/run_suite.py
View file @
ffd1a26e
...
@@ -62,6 +62,11 @@ suites = {
...
@@ -62,6 +62,11 @@ suites = {
TestFile
(
"test_openai_adapter.py"
,
1
),
TestFile
(
"test_openai_adapter.py"
,
1
),
TestFile
(
"test_openai_function_calling.py"
,
60
),
TestFile
(
"test_openai_function_calling.py"
,
60
),
TestFile
(
"test_openai_server.py"
,
149
),
TestFile
(
"test_openai_server.py"
,
149
),
TestFile
(
"openai/test_server.py"
,
120
),
TestFile
(
"openai/test_protocol.py"
,
60
),
TestFile
(
"openai/test_serving_chat.py"
,
120
),
TestFile
(
"openai/test_serving_completions.py"
,
120
),
TestFile
(
"openai/test_serving_embedding.py"
,
120
),
TestFile
(
"test_openai_server_hidden_states.py"
,
240
),
TestFile
(
"test_openai_server_hidden_states.py"
,
240
),
TestFile
(
"test_penalty.py"
,
41
),
TestFile
(
"test_penalty.py"
,
41
),
TestFile
(
"test_page_size.py"
,
60
),
TestFile
(
"test_page_size.py"
,
60
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment