Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3f8d5dd
Commit
a3f8d5dd
authored
Dec 17, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori
parents
8d75f22e
f34eca5f
Changes
499
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2687 additions
and
105 deletions
+2687
-105
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
+276
-0
tests/engine/test_arg_utils.py
tests/engine/test_arg_utils.py
+18
-4
tests/entrypoints/openai/parser/test_harmony_utils.py
tests/entrypoints/openai/parser/test_harmony_utils.py
+666
-85
tests/entrypoints/openai/test_chat_error.py
tests/entrypoints/openai/test_chat_error.py
+227
-0
tests/entrypoints/openai/test_completion_error.py
tests/entrypoints/openai/test_completion_error.py
+216
-0
tests/entrypoints/openai/test_messages.py
tests/entrypoints/openai/test_messages.py
+6
-3
tests/entrypoints/openai/test_responses_error.py
tests/entrypoints/openai/test_responses_error.py
+89
-0
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+645
-1
tests/entrypoints/openai/test_serving_engine.py
tests/entrypoints/openai/test_serving_engine.py
+1
-1
tests/entrypoints/openai/test_serving_responses.py
tests/entrypoints/openai/test_serving_responses.py
+3
-3
tests/entrypoints/openai/test_sparse_tensor_validation.py
tests/entrypoints/openai/test_sparse_tensor_validation.py
+342
-0
tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py
...ypoints/openai/tool_parsers/test_gigachat3_tool_parser.py
+1
-1
tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py
...ntrypoints/openai/tool_parsers/test_hermes_tool_parser.py
+1
-1
tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py
...ints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py
+1
-1
tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
...oints/openai/tool_parsers/test_llama3_json_tool_parser.py
+1
-1
tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py
...s/openai/tool_parsers/test_llama4_pythonic_tool_parser.py
+1
-1
tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py
...entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py
+1
-1
tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py
...rypoints/openai/tool_parsers/test_pythonic_tool_parser.py
+1
-1
tests/entrypoints/openai/tool_parsers/utils.py
tests/entrypoints/openai/tool_parsers/utils.py
+1
-1
tests/entrypoints/openai/utils.py
tests/entrypoints/openai/utils.py
+190
-0
No files found.
tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
0 → 100644
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4
from
dataclasses
import
dataclass
import
pytest
import
torch
from
tests.kernels.moe.utils
import
make_test_quant_config
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.distributed.eplb.rebalance_execute
import
rearrange_expert_weights_inplace
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
get_dp_group
,
)
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
from
vllm.model_executor.layers.quantization.modelopt
import
(
ModelOptNvFp4Config
,
ModelOptNvFp4FusedMoE
,
)
from
.eplb_utils
import
distributed_run
,
set_env_vars_and_device
@dataclass
class TestConfig:
    """Size parameters shared by the EPLB / FusedMoE NVFP4 test helpers.

    One instance is built by ``test_eplb_fml`` and handed to every worker,
    which reads the fields below when creating layers and input tensors.
    """

    # The class name starts with "Test", so pytest's default discovery would
    # try to collect it and warn about the dataclass-generated __init__.
    # Opting out keeps the test run free of collection warnings. This is a
    # plain class attribute (no annotation), so it is NOT a dataclass field
    # and the constructor signature is unchanged.
    __test__ = False

    # Number of FusedMoE layers built per rank.
    num_layers: int
    # Global expert count; passed to FusedMoE as ``num_experts``.
    num_experts: int
    # Expert count passed to ``create_weights`` / ``make_test_quant_config``.
    num_local_experts: int
    # Passed to FusedMoE as ``top_k``.
    num_topk: int
    # Passed to FusedMoE as ``hidden_size``; also the second dimension of the
    # random hidden-state inputs.
    hidden_size: int
    # Passed to FusedMoE as ``intermediate_size``.
    intermediate_size: int
    # First dimension of the random hidden-state and router-logit inputs.
    num_tokens: int
def make_fused_moe_layer(
    rank: int,
    layer_idx: int,
    test_config: TestConfig,
) -> FusedMoE:
    """Build one NVFP4-quantized ``FusedMoE`` layer on ``cuda:{rank}``.

    The layer is created with a ``ModelOptNvFp4Config``, its weights are
    allocated through ``ModelOptNvFp4FusedMoE.create_weights``, the quantized
    weights come from ``make_test_quant_config`` and every scale tensor is
    overwritten with random values before the usual post-load processing.

    Args:
        rank: Index of the CUDA device the layer is moved to.
        layer_idx: Only used to build the layer prefix ``dummy_layer_<idx>``.
        test_config: Sizes (experts, top-k, hidden/intermediate size).

    Returns:
        The initialized ``FusedMoE`` module.
    """
    target_device = torch.device(f"cuda:{rank}")

    nvfp4_config = ModelOptNvFp4Config(
        is_checkpoint_nvfp4_serialized=True,
        kv_cache_quant_algo=None,
        exclude_modules=[],
    )

    layer = FusedMoE(
        num_experts=test_config.num_experts,
        top_k=test_config.num_topk,
        hidden_size=test_config.hidden_size,
        intermediate_size=test_config.intermediate_size,
        prefix=f"dummy_layer_{layer_idx}",
        activation="silu",
        is_act_and_mul=True,
        params_dtype=torch.bfloat16,
        quant_config=nvfp4_config,
    )

    quant_method = ModelOptNvFp4FusedMoE(nvfp4_config, layer)
    quant_method.create_weights(
        layer,
        test_config.num_local_experts,
        test_config.hidden_size,
        test_config.intermediate_size,
        params_dtype=torch.uint8,
        global_num_experts=test_config.num_experts,
    )

    layer = layer.to(target_device)

    # Only the two quantized weight tensors are used; the third return value
    # is not needed because the quant method above already holds its config.
    w1_q, w2_q, _unused_quant_config = make_test_quant_config(
        test_config.num_local_experts,
        test_config.intermediate_size,
        test_config.hidden_size,
        in_dtype=torch.bfloat16,
        quant_dtype="nvfp4",
        block_shape=None,
        per_act_token_quant=False,
    )
    layer.w13_weight.data = w1_q
    layer.w2_weight.data = w2_q

    # Randomize the scales. The iteration order matches the order of the
    # random draws, so it must stay as listed.
    for scale_name in (
        "w2_input_scale",
        "w13_input_scale",
        "w2_weight_scale_2",
        "w13_weight_scale_2",
    ):
        scale = getattr(layer, scale_name)
        scale.data = torch.randn_like(scale.data) / 5

    # The block scales are drawn in the default dtype and then cast to the
    # dtype the parameter was created with.
    for scale_name in ("w2_weight_scale", "w13_weight_scale"):
        scale = getattr(layer, scale_name)
        random_scale = torch.randn(scale.data.shape, device=target_device) / 5
        scale.data = random_scale.to(scale.data.dtype)

    quant_method.process_weights_after_loading(layer)
    layer.maybe_init_modular_kernel()

    return layer
def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
    """Per-rank worker: FusedMoE outputs must survive an EPLB expert shuffle.

    The worker builds ``test_config.num_layers`` NVFP4 FusedMoE layers, runs
    them on random inputs, rearranges the expert weights according to a random
    permutation per layer, installs the matching EPLB state on each layer and
    runs the same inputs again. Outputs before and after are compared.

    Args:
        env: Forwarded unchanged to ``set_env_vars_and_device``.
        world_size: Used as the data-parallel size and as the number of
            entries in ``num_tokens_across_dp``.
        test_config: Sizes for layers and input tensors.

    Raises:
        AssertionError: From ``torch.testing.assert_close`` when a layer's
            output differs beyond ``atol=1e-1`` / ``rtol=1e-1``.
    """
    set_env_vars_and_device(env)

    vllm_config = VllmConfig()
    vllm_config.parallel_config.data_parallel_size = world_size
    vllm_config.parallel_config.enable_expert_parallel = True

    # NOTE(review): the listing this was taken from had lost its indentation;
    # the whole remainder is assumed to run inside this config context.
    with set_current_vllm_config(vllm_config):
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=1, pipeline_model_parallel_size=1
        )

        ep_group = get_dp_group().cpu_group
        ep_rank = torch.distributed.get_rank()

        device = torch.device(f"cuda:{ep_rank}")
        num_layers = test_config.num_layers
        num_global_experts = test_config.num_experts

        fml_layers = [
            make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device)
            for layer_idx in range(num_layers)
        ]
        rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]

        # Inputs are drawn layer by layer (hidden states, then router logits)
        # so the random stream is consumed in a fixed order.
        hidden_states = []
        router_logits = []
        for _ in range(num_layers):
            hidden_states.append(
                torch.randn(
                    (test_config.num_tokens, test_config.hidden_size),
                    dtype=torch.bfloat16,
                    device=device,
                )
            )
            router_logits.append(
                torch.randn(
                    (test_config.num_tokens, test_config.num_experts),
                    dtype=torch.bfloat16,
                    device=device,
                )
            )

        def run_all_layers():
            """Run every layer on a clone of its inputs and collect outputs."""
            outputs = []
            with set_forward_context(
                {},
                num_tokens=test_config.num_tokens,
                num_tokens_across_dp=torch.tensor(
                    [test_config.num_tokens] * world_size,
                    device="cpu",
                    dtype=torch.int,
                ),
                vllm_config=vllm_config,
            ):
                for lidx, fml in enumerate(fml_layers):
                    outputs.append(
                        fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
                    )
            return outputs

        out_before_shuffle = run_all_layers()

        # Identity placement for every layer, built directly as int64 rather
        # than through the legacy float ``torch.Tensor(range(...))`` path.
        indices = torch.arange(num_global_experts, dtype=torch.long).repeat(
            num_layers, 1
        )

        # New placement: an independent random permutation per layer.
        shuffled_indices = torch.zeros_like(indices)
        for lidx in range(num_layers):
            shuffled_indices[lidx] = torch.randperm(num_global_experts)

        rearrange_expert_weights_inplace(
            indices,
            shuffled_indices,
            rank_expert_weights,
            ep_group,
            is_profile=False,
        )

        # Invert each permutation: shuffled_indices maps physical -> logical,
        # the layers are given logical -> physical.
        logical_to_physical_map_list = []
        for lidx in range(num_layers):
            physical_to_logical_map = shuffled_indices[lidx].to(device)
            logical_to_physical_map = torch.empty(
                (num_global_experts,), dtype=torch.int32, device=device
            )
            logical_to_physical_map[physical_to_logical_map] = torch.arange(
                0, num_global_experts, dtype=torch.int32, device=device
            )
            logical_to_physical_map_list.append(
                logical_to_physical_map.reshape(num_global_experts, 1)
            )

        logical_to_physical_map = torch.stack(logical_to_physical_map_list)

        for lidx, fml in enumerate(fml_layers):
            # Fresh tensors per layer, as in the original; one replica each.
            logical_replica_count = torch.ones(
                (num_layers, num_global_experts),
                dtype=torch.int32,
                device=device,
            )
            fml.enable_eplb = True
            fml.set_eplb_state(
                lidx,
                torch.zeros(
                    (num_layers, num_global_experts),
                    dtype=torch.int32,
                    device=device,
                ),
                logical_to_physical_map,
                logical_replica_count,
            )

        out_after_shuffle = run_all_layers()

        for lidx in range(num_layers):
            torch.testing.assert_close(
                out_before_shuffle[lidx],
                out_after_shuffle[lidx],
                atol=1e-1,
                rtol=1e-1,
            )
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("num_layers", [8])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("intermediate_size", [256])
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("backend", ["latency", "throughput"])
def test_eplb_fml(
    world_size: int,
    num_layers: int,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    num_tokens: int,
    backend: str,
    monkeypatch,
):
    """Launch ``_test_eplb_fml`` on ``world_size`` ranks for one backend.

    Sets the FlashInfer FP4 MoE environment variables for the selected
    ``backend``, skips when fewer than ``world_size`` CUDA devices are
    visible, and otherwise hands a ``TestConfig`` to ``distributed_run``.
    """
    flashinfer_env = (
        ("VLLM_USE_FLASHINFER_MOE_FP4", "1"),
        ("VLLM_FLASHINFER_MOE_BACKEND", backend),
    )
    for env_name, env_value in flashinfer_env:
        monkeypatch.setenv(env_name, env_value)

    visible_gpus = torch.cuda.device_count()
    if visible_gpus < world_size:
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    test_config = TestConfig(
        num_layers=num_layers,
        num_experts=num_experts,
        # Experts are split evenly across the ranks.
        num_local_experts=num_experts // world_size,
        num_topk=4,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_tokens=num_tokens,
    )

    distributed_run(
        _test_eplb_fml,
        world_size,
        test_config,
    )
tests/engine/test_arg_utils.py
View file @
a3f8d5dd
...
@@ -350,21 +350,35 @@ def test_human_readable_model_len():
...
@@ -350,21 +350,35 @@ def test_human_readable_model_len():
assert
args
.
max_model_len
==
1_000_000
assert
args
.
max_model_len
==
1_000_000
args
=
parser
.
parse_args
([
"--max-model-len"
,
"10k"
])
args
=
parser
.
parse_args
([
"--max-model-len"
,
"10k"
])
assert
args
.
max_model_len
==
10_000
assert
args
.
max_model_len
==
10_000
args
=
parser
.
parse_args
([
"--max-model-len"
,
"2g"
])
assert
args
.
max_model_len
==
2_000_000_000
args
=
parser
.
parse_args
([
"--max-model-len"
,
"2t"
])
assert
args
.
max_model_len
==
2_000_000_000_000
# Capital
# Capital
args
=
parser
.
parse_args
([
"--max-model-len"
,
"3K"
])
args
=
parser
.
parse_args
([
"--max-model-len"
,
"3K"
])
assert
args
.
max_model_len
==
10
24
*
3
assert
args
.
max_model_len
==
2
**
10
*
3
args
=
parser
.
parse_args
([
"--max-model-len"
,
"10M"
])
args
=
parser
.
parse_args
([
"--max-model-len"
,
"10M"
])
assert
args
.
max_model_len
==
2
**
20
*
10
assert
args
.
max_model_len
==
2
**
20
*
10
args
=
parser
.
parse_args
([
"--max-model-len"
,
"4G"
])
assert
args
.
max_model_len
==
2
**
30
*
4
args
=
parser
.
parse_args
([
"--max-model-len"
,
"4T"
])
assert
args
.
max_model_len
==
2
**
40
*
4
# Decimal values
# Decimal values
args
=
parser
.
parse_args
([
"--max-model-len"
,
"10.2k"
])
args
=
parser
.
parse_args
([
"--max-model-len"
,
"10.2k"
])
assert
args
.
max_model_len
==
10200
assert
args
.
max_model_len
==
10200
# ..truncated to the nearest int
# ..truncated to the nearest int
args
=
parser
.
parse_args
([
"--max-model-len"
,
"10.212345k"
])
args
=
parser
.
parse_args
([
"--max-model-len"
,
"10.212345
1234567
k"
])
assert
args
.
max_model_len
==
10212
assert
args
.
max_model_len
==
10212
args
=
parser
.
parse_args
([
"--max-model-len"
,
"10.2123451234567m"
])
assert
args
.
max_model_len
==
10212345
args
=
parser
.
parse_args
([
"--max-model-len"
,
"10.2123451234567g"
])
assert
args
.
max_model_len
==
10212345123
args
=
parser
.
parse_args
([
"--max-model-len"
,
"10.2123451234567t"
])
assert
args
.
max_model_len
==
10212345123456
# Invalid (do not allow decimals with binary multipliers)
# Invalid (do not allow decimals with binary multipliers)
for
invalid
in
[
"1a"
,
"pwd"
,
"10.24"
,
"1.23M"
]:
for
invalid
in
[
"1a"
,
"pwd"
,
"10.24"
,
"1.23M"
,
"1.22T"
]:
with
pytest
.
raises
(
ArgumentError
):
with
pytest
.
raises
(
ArgumentError
):
args
=
parser
.
parse_args
([
"--max-model-len"
,
invalid
])
parser
.
parse_args
([
"--max-model-len"
,
invalid
])
tests/entrypoints/openai/parser/test_harmony_utils.py
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
openai.types.responses
import
ResponseFunctionToolCall
,
ResponseReasoningItem
from
openai.types.responses
import
ResponseFunctionToolCall
,
ResponseReasoningItem
from
openai.types.responses.response_output_item
import
McpCall
from
openai.types.responses.response_output_item
import
McpCall
from
openai_harmony
import
Author
,
Message
,
Role
,
TextContent
from
openai_harmony
import
Author
,
Message
,
Role
,
TextContent
from
tests.entrypoints.openai.utils
import
verify_harmony_messages
from
vllm.entrypoints.openai.parser.harmony_utils
import
(
from
vllm.entrypoints.openai.parser.harmony_utils
import
(
auto_drop_analysis_messages
,
get_encoding
,
has_custom_tools
,
has_custom_tools
,
parse_chat_input_to_harmony_message
,
parse_chat_output
,
parse_input_to_harmony_message
,
parse_input_to_harmony_message
,
parse_output_message
,
parse_output_message
,
)
)
class
TestParseInputToHarmonyMessage
:
class
TestCommonParseInputToHarmonyMessage
:
"""Tests for parse_input_to_harmony_message function."""
"""
Tests for scenarios that are common to both Chat Completion
parse_chat_input_to_harmony_message and Responses API
parse_input_to_harmony_message functions.
"""
@
pytest
.
fixture
(
params
=
[
parse_chat_input_to_harmony_message
,
parse_input_to_harmony_message
]
)
def
parse_function
(
self
,
request
):
return
request
.
param
def
test_assistant_message_with_tool_calls
(
self
):
def
test_assistant_message_with_tool_calls
(
self
,
parse_function
):
"""Test parsing assistant message with tool calls."""
"""Test parsing assistant message with tool calls."""
chat_msg
=
{
chat_msg
=
{
"role"
:
"assistant"
,
"role"
:
"assistant"
,
...
@@ -35,7 +51,7 @@ class TestParseInputToHarmonyMessage:
...
@@ -35,7 +51,7 @@ class TestParseInputToHarmonyMessage:
],
],
}
}
messages
=
parse_
input_to_harmony_message
(
chat_msg
)
messages
=
parse_
function
(
chat_msg
)
assert
len
(
messages
)
==
2
assert
len
(
messages
)
==
2
...
@@ -53,7 +69,7 @@ class TestParseInputToHarmonyMessage:
...
@@ -53,7 +69,7 @@ class TestParseInputToHarmonyMessage:
assert
messages
[
1
].
recipient
==
"functions.search_web"
assert
messages
[
1
].
recipient
==
"functions.search_web"
assert
messages
[
1
].
content_type
==
"json"
assert
messages
[
1
].
content_type
==
"json"
def
test_assistant_message_with_empty_tool_call_arguments
(
self
):
def
test_assistant_message_with_empty_tool_call_arguments
(
self
,
parse_function
):
"""Test parsing assistant message with tool call having None arguments."""
"""Test parsing assistant message with tool call having None arguments."""
chat_msg
=
{
chat_msg
=
{
"role"
:
"assistant"
,
"role"
:
"assistant"
,
...
@@ -67,12 +83,152 @@ class TestParseInputToHarmonyMessage:
...
@@ -67,12 +83,152 @@ class TestParseInputToHarmonyMessage:
],
],
}
}
messages
=
parse_
input_to_harmony_message
(
chat_msg
)
messages
=
parse_
function
(
chat_msg
)
assert
len
(
messages
)
==
1
assert
len
(
messages
)
==
1
assert
messages
[
0
].
content
[
0
].
text
==
""
assert
messages
[
0
].
content
[
0
].
text
==
""
assert
messages
[
0
].
recipient
==
"functions.get_current_time"
assert
messages
[
0
].
recipient
==
"functions.get_current_time"
def
test_system_message
(
self
,
parse_function
):
"""Test parsing system message."""
chat_msg
=
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
,
}
messages
=
parse_function
(
chat_msg
)
assert
len
(
messages
)
==
1
# System messages are converted using Message.from_dict
# which should preserve the role
assert
messages
[
0
].
author
.
role
==
Role
.
SYSTEM
def
test_developer_message
(
self
,
parse_function
):
"""Test parsing developer message."""
chat_msg
=
{
"role"
:
"developer"
,
"content"
:
"Use concise language"
,
}
messages
=
parse_function
(
chat_msg
)
assert
len
(
messages
)
==
1
assert
messages
[
0
].
author
.
role
==
Role
.
DEVELOPER
def
test_user_message_with_string_content
(
self
,
parse_function
):
"""Test parsing user message with string content."""
chat_msg
=
{
"role"
:
"user"
,
"content"
:
"What's the weather in San Francisco?"
,
}
messages
=
parse_function
(
chat_msg
)
assert
len
(
messages
)
==
1
assert
messages
[
0
].
author
.
role
==
Role
.
USER
assert
messages
[
0
].
content
[
0
].
text
==
"What's the weather in San Francisco?"
def
test_user_message_with_array_content
(
self
,
parse_function
):
"""Test parsing user message with array content."""
chat_msg
=
{
"role"
:
"user"
,
"content"
:
[
{
"text"
:
"What's in this image? "
},
{
"text"
:
"Please describe it."
},
],
}
messages
=
parse_function
(
chat_msg
)
assert
len
(
messages
)
==
1
assert
messages
[
0
].
author
.
role
==
Role
.
USER
assert
len
(
messages
[
0
].
content
)
==
2
assert
messages
[
0
].
content
[
0
].
text
==
"What's in this image? "
assert
messages
[
0
].
content
[
1
].
text
==
"Please describe it."
def
test_assistant_message_with_string_content
(
self
,
parse_function
):
"""Test parsing assistant message with string content (no tool calls)."""
chat_msg
=
{
"role"
:
"assistant"
,
"content"
:
"Hello! How can I help you today?"
,
}
messages
=
parse_function
(
chat_msg
)
assert
len
(
messages
)
==
1
assert
messages
[
0
].
author
.
role
==
Role
.
ASSISTANT
assert
messages
[
0
].
content
[
0
].
text
==
"Hello! How can I help you today?"
def
test_pydantic_model_input
(
self
,
parse_function
):
"""Test parsing Pydantic model input (has model_dump method)."""
class
MockPydanticModel
:
def
model_dump
(
self
,
exclude_none
=
True
):
return
{
"role"
:
"user"
,
"content"
:
"Test message"
,
}
chat_msg
=
MockPydanticModel
()
messages
=
parse_function
(
chat_msg
)
assert
len
(
messages
)
==
1
assert
messages
[
0
].
author
.
role
==
Role
.
USER
assert
messages
[
0
].
content
[
0
].
text
==
"Test message"
def
test_tool_call_with_missing_function_fields
(
self
,
parse_function
):
"""Test parsing tool call with missing name or arguments."""
chat_msg
=
{
"role"
:
"assistant"
,
"tool_calls"
:
[
{
"function"
:
{}
# Missing both name and arguments
}
],
}
messages
=
parse_function
(
chat_msg
)
assert
len
(
messages
)
==
1
assert
messages
[
0
].
recipient
==
"functions."
assert
messages
[
0
].
content
[
0
].
text
==
""
def
test_array_content_with_missing_text
(
self
,
parse_function
):
"""Test parsing array content where text field is missing."""
chat_msg
=
{
"role"
:
"user"
,
"content"
:
[
{},
# Missing text field
{
"text"
:
"actual text"
},
],
}
messages
=
parse_function
(
chat_msg
)
assert
len
(
messages
)
==
1
assert
len
(
messages
[
0
].
content
)
==
2
assert
messages
[
0
].
content
[
0
].
text
==
""
assert
messages
[
0
].
content
[
1
].
text
==
"actual text"
class
TestParseInputToHarmonyMessage
:
"""
Tests for scenarios that are specific to the Responses API
parse_input_to_harmony_message function.
"""
def
test_message_with_empty_content
(
self
):
"""Test parsing message with empty string content."""
chat_msg
=
{
"role"
:
"user"
,
"content"
:
""
,
}
messages
=
parse_input_to_harmony_message
(
chat_msg
)
assert
len
(
messages
)
==
1
assert
messages
[
0
].
content
[
0
].
text
==
""
def
test_tool_message_with_string_content
(
self
):
def
test_tool_message_with_string_content
(
self
):
"""Test parsing tool message with string content."""
"""Test parsing tool message with string content."""
chat_msg
=
{
chat_msg
=
{
...
@@ -111,6 +267,7 @@ class TestParseInputToHarmonyMessage:
...
@@ -111,6 +267,7 @@ class TestParseInputToHarmonyMessage:
assert
len
(
messages
)
==
1
assert
len
(
messages
)
==
1
assert
messages
[
0
].
author
.
role
==
Role
.
TOOL
assert
messages
[
0
].
author
.
role
==
Role
.
TOOL
assert
messages
[
0
].
author
.
name
==
"functions.search_results"
assert
messages
[
0
].
content
[
0
].
text
==
"Result 1: Result 2: Result 3"
assert
messages
[
0
].
content
[
0
].
text
==
"Result 1: Result 2: Result 3"
def
test_tool_message_with_empty_content
(
self
):
def
test_tool_message_with_empty_content
(
self
):
...
@@ -124,140 +281,564 @@ class TestParseInputToHarmonyMessage:
...
@@ -124,140 +281,564 @@ class TestParseInputToHarmonyMessage:
messages
=
parse_input_to_harmony_message
(
chat_msg
)
messages
=
parse_input_to_harmony_message
(
chat_msg
)
assert
len
(
messages
)
==
1
assert
len
(
messages
)
==
1
assert
messages
[
0
].
author
.
role
==
Role
.
TOOL
assert
messages
[
0
].
author
.
name
==
"functions.empty_tool"
assert
messages
[
0
].
content
[
0
].
text
==
""
assert
messages
[
0
].
content
[
0
].
text
==
""
def
test_system_message
(
self
):
"""Test parsing system message."""
class
TestParseChatInputToHarmonyMessage
:
"""
Tests for scenarios that are specific to the Chat Completion API
parse_chat_input_to_harmony_message function.
"""
def
test_user_message_with_empty_content
(
self
):
chat_msg
=
{
chat_msg
=
{
"role"
:
"
system
"
,
"role"
:
"
user
"
,
"content"
:
"
You are a helpful assistant
"
,
"content"
:
""
,
}
}
messages
=
parse_input_to_harmony_message
(
chat_msg
)
messages
=
parse_
chat_
input_to_harmony_message
(
chat_msg
)
assert
len
(
messages
)
==
1
verify_harmony_messages
(
# System messages are converted using Message.from_dict
messages
,
# which should preserve the role
[
assert
messages
[
0
].
author
.
role
==
Role
.
SYSTEM
{
"role"
:
"user"
,
"content"
:
""
,
},
],
)
def
test_developer_message
(
self
):
def
test_user_message_with_none_content
(
self
):
"""Test parsing developer message."""
chat_msg
=
{
chat_msg
=
{
"role"
:
"
develop
er"
,
"role"
:
"
us
er"
,
"content"
:
"Use concise language"
,
"content"
:
None
,
}
}
messages
=
parse_input_to_harmony_message
(
chat_msg
)
messages
=
parse_
chat_
input_to_harmony_message
(
chat_msg
)
assert
len
(
messages
)
==
1
verify_harmony_messages
(
assert
messages
[
0
].
author
.
role
==
Role
.
DEVELOPER
messages
,
[
{
"role"
:
"user"
,
"content"
:
""
,
},
],
)
def
test_user_message_with_string_content
(
self
):
def
test_assistant_message_with_empty_content
(
self
):
"""Test parsing user message with string content."""
chat_msg
=
{
chat_msg
=
{
"role"
:
"
user
"
,
"role"
:
"
assistant
"
,
"content"
:
"
What's the weather in San Francisco?
"
,
"content"
:
""
,
}
}
messages
=
parse_input_to_harmony_message
(
chat_msg
)
messages
=
parse_
chat_
input_to_harmony_message
(
chat_msg
)
assert
len
(
messages
)
==
1
assert
len
(
messages
)
==
0
assert
messages
[
0
].
author
.
role
==
Role
.
USER
assert
messages
[
0
].
content
[
0
].
text
==
"What's the weather in San Francisco?"
def
test_user_message_with_array_content
(
self
):
def
test_assistant_message_with_none_content
(
self
):
"""Test parsing user message with array content."""
chat_msg
=
{
chat_msg
=
{
"role"
:
"user"
,
"role"
:
"assistant"
,
"content"
:
[
"content"
:
None
,
{
"text"
:
"What's in this image? "
},
}
{
"text"
:
"Please describe it."
},
messages
=
parse_chat_input_to_harmony_message
(
chat_msg
)
assert
len
(
messages
)
==
0
def
test_assistant_message_with_content_but_empty_reasoning
(
self
):
chat_msg
=
{
"role"
:
"assistant"
,
"content"
:
"The answer is 4."
,
"reasoning"
:
""
,
}
messages
=
parse_chat_input_to_harmony_message
(
chat_msg
)
verify_harmony_messages
(
messages
,
[
{
"role"
:
"assistant"
,
"channel"
:
"final"
,
"content"
:
"The answer is 4."
,
},
],
],
)
def
test_assistant_message_with_reasoning_but_empty_content
(
self
):
chat_msg
=
{
"role"
:
"assistant"
,
"reasoning"
:
"I'm thinking about the user's question."
,
"content"
:
""
,
}
}
messages
=
parse_input_to_harmony_message
(
chat_msg
)
messages
=
parse_
chat_
input_to_harmony_message
(
chat_msg
)
assert
len
(
messages
)
==
1
verify_harmony_messages
(
assert
messages
[
0
].
author
.
role
==
Role
.
USER
messages
,
assert
len
(
messages
[
0
].
content
)
==
2
[
assert
messages
[
0
].
content
[
0
].
text
==
"What's in this image? "
{
assert
messages
[
0
].
content
[
1
].
text
==
"Please describe it."
"role"
:
"assistant"
,
"channel"
:
"analysis"
,
"content"
:
"I'm thinking about the user's question."
,
},
],
)
def
test_assistant_message_with_string_content
(
self
):
def
test_assistant_message_with_reasoning_but_none_content
(
self
):
"""Test parsing assistant message with string content (no tool calls)."""
chat_msg
=
{
chat_msg
=
{
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"content"
:
"Hello! How can I help you today?"
,
"reasoning"
:
"I'm thinking about the user's question."
,
"content"
:
None
,
}
}
messages
=
parse_input_to_harmony_message
(
chat_msg
)
messages
=
parse_
chat_
input_to_harmony_message
(
chat_msg
)
assert
len
(
messages
)
==
1
verify_harmony_messages
(
assert
messages
[
0
].
author
.
role
==
Role
.
ASSISTANT
messages
,
assert
messages
[
0
].
content
[
0
].
text
==
"Hello! How can I help you today?"
[
{
"role"
:
"assistant"
,
"channel"
:
"analysis"
,
"content"
:
"I'm thinking about the user's question."
,
},
],
)
def
test_pydantic_model_input
(
self
):
def
test_assistant_message_with_tool_calls_but_no_content
(
self
):
"""Test parsing Pydantic model input (has model_dump method)."""
chat_msg
=
{
"role"
:
"assistant"
,
"tool_calls"
:
[
{
"function"
:
{
"name"
:
"get_weather"
,
"arguments"
:
'{"location": "San Francisco"}'
,
}
}
],
}
class
MockPydanticModel
:
messages
=
parse_chat_input_to_harmony_message
(
chat_msg
)
def
model_dump
(
self
,
exclude_none
=
True
):
return
{
verify_harmony_messages
(
"role"
:
"user"
,
messages
,
"content"
:
"Test message"
,
[
{
"role"
:
"assistant"
,
"channel"
:
"commentary"
,
"recipient"
:
"functions.get_weather"
,
"content"
:
'{"location": "San Francisco"}'
,
"content_type"
:
"json"
,
},
],
)
def
test_assistant_message_with_tool_calls_and_content
(
self
):
chat_msg
=
{
"role"
:
"assistant"
,
"tool_calls"
:
[
{
"function"
:
{
"name"
:
"get_weather"
,
"arguments"
:
'{"location": "San Francisco"}'
,
}
}
}
],
"content"
:
"I'll call the tool."
,
}
chat_msg
=
MockPydanticModel
()
messages
=
parse_chat_input_to_harmony_message
(
chat_msg
)
messages
=
parse_input_to_harmony_message
(
chat_msg
)
assert
len
(
messages
)
==
1
verify_harmony_messages
(
assert
messages
[
0
].
author
.
role
==
Role
.
USER
messages
,
assert
messages
[
0
].
content
[
0
].
text
==
"Test message"
[
{
"role"
:
"assistant"
,
"channel"
:
"commentary"
,
"content"
:
"I'll call the tool."
,
},
{
"role"
:
"assistant"
,
"channel"
:
"commentary"
,
"recipient"
:
"functions.get_weather"
,
"content"
:
'{"location": "San Francisco"}'
,
"content_type"
:
"json"
,
},
],
)
def
test_message_with_empty_content
(
self
):
def
test_assistant_message_with_tool_calls_and_reasoning
(
self
):
"""Test parsing message with empty string content."""
chat_msg
=
{
chat_msg
=
{
"role"
:
"user"
,
"role"
:
"assistant"
,
"content"
:
""
,
"tool_calls"
:
[
{
"function"
:
{
"name"
:
"get_weather"
,
"arguments"
:
'{"location": "San Francisco"}'
,
}
}
],
"reasoning"
:
"I should use the get_weather tool."
,
}
}
messages
=
parse_input_to_harmony_message
(
chat_msg
)
messages
=
parse_
chat_
input_to_harmony_message
(
chat_msg
)
assert
len
(
messages
)
==
1
verify_harmony_messages
(
assert
messages
[
0
].
content
[
0
].
text
==
""
messages
,
[
{
"role"
:
"assistant"
,
"channel"
:
"analysis"
,
"content"
:
"I should use the get_weather tool."
,
},
{
"role"
:
"assistant"
,
"channel"
:
"commentary"
,
"recipient"
:
"functions.get_weather"
,
"content"
:
'{"location": "San Francisco"}'
,
"content_type"
:
"json"
,
},
],
)
def
test_tool_call_with_missing_function_fields
(
self
):
def
test_assistant_message_with_tool_calls_and_reasoning_and_content
(
self
):
"""Test parsing tool call with missing name or arguments."""
chat_msg
=
{
chat_msg
=
{
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"tool_calls"
:
[
"tool_calls"
:
[
{
{
"function"
:
{}
# Missing both name and arguments
"function"
:
{
"name"
:
"get_weather"
,
"arguments"
:
'{"location": "San Francisco"}'
,
}
}
}
],
],
"reasoning"
:
"I should use the get_weather tool."
,
"content"
:
"I'll call the tool."
,
}
}
messages
=
parse_input_to_harmony_message
(
chat_msg
)
messages
=
parse_
chat_
input_to_harmony_message
(
chat_msg
)
assert
len
(
messages
)
==
1
verify_harmony_messages
(
assert
messages
[
0
].
recipient
==
"functions."
messages
,
assert
messages
[
0
].
content
[
0
].
text
==
""
[
{
"role"
:
"assistant"
,
"channel"
:
"commentary"
,
"content"
:
"I'll call the tool."
,
},
{
"role"
:
"assistant"
,
"channel"
:
"analysis"
,
"content"
:
"I should use the get_weather tool."
,
},
{
"role"
:
"assistant"
,
"channel"
:
"commentary"
,
"recipient"
:
"functions.get_weather"
,
"content"
:
'{"location": "San Francisco"}'
,
"content_type"
:
"json"
,
},
],
)
def
test_array_content_with_missing_text
(
self
):
def
test_tool_message_with_string_content
(
self
):
"""Test parsing array content where text field is missing."""
tool_id_names
=
{
"call_123"
:
"get_weather"
,
}
chat_msg
=
{
chat_msg
=
{
"role"
:
"user"
,
"role"
:
"tool"
,
"tool_call_id"
:
"call_123"
,
"content"
:
"The weather in San Francisco is sunny, 72°F"
,
}
messages
=
parse_chat_input_to_harmony_message
(
chat_msg
,
tool_id_names
=
tool_id_names
)
verify_harmony_messages
(
messages
,
[
{
"role"
:
"tool"
,
"name"
:
"functions.get_weather"
,
"content"
:
"The weather in San Francisco is sunny, 72°F"
,
"channel"
:
"commentary"
,
},
],
)
def
test_tool_message_with_array_content
(
self
):
tool_id_names
=
{
"call_123"
:
"search_results"
,
}
chat_msg
=
{
"role"
:
"tool"
,
"tool_call_id"
:
"call_123"
,
"content"
:
[
"content"
:
[
{},
# Missing text field
{
"type"
:
"text"
,
"text"
:
"Result 1: "
},
{
"text"
:
"actual text"
},
{
"type"
:
"text"
,
"text"
:
"Result 2: "
},
{
"type"
:
"image"
,
"url"
:
"http://example.com/img.png"
,
},
# Should be ignored
{
"type"
:
"text"
,
"text"
:
"Result 3"
},
],
],
}
}
messages
=
parse_input_to_harmony_message
(
chat_msg
)
messages
=
parse_chat_input_to_harmony_message
(
chat_msg
,
tool_id_names
=
tool_id_names
)
assert
len
(
messages
)
==
1
verify_harmony_messages
(
assert
len
(
messages
[
0
].
content
)
==
2
messages
,
assert
messages
[
0
].
content
[
0
].
text
==
""
[
assert
messages
[
0
].
content
[
1
].
text
==
"actual text"
{
"role"
:
"tool"
,
"name"
:
"functions.search_results"
,
"content"
:
"Result 1: Result 2: Result 3"
,
"channel"
:
"commentary"
,
},
],
)
def test_tool_message_with_empty_content(self):
    """A tool message whose content is the empty string.

    The resulting Harmony message is expected to be a ``tool`` message named
    ``functions.empty_tool`` on the ``commentary`` channel with empty content.
    """
    call_id = "call_123"
    chat_msg = {"role": "tool", "tool_call_id": call_id, "content": ""}
    expected = [
        {
            "role": "tool",
            "name": "functions.empty_tool",
            "content": "",
            "channel": "commentary",
        },
    ]

    messages = parse_chat_input_to_harmony_message(
        chat_msg, tool_id_names={call_id: "empty_tool"}
    )

    verify_harmony_messages(messages, expected)
def test_tool_message_with_none_content(self):
    """A tool message whose content is ``None``.

    The expectation is the same as for empty-string content: the Harmony
    message carries ``""`` as its content.
    """
    call_id = "call_123"
    chat_msg = {"role": "tool", "tool_call_id": call_id, "content": None}
    expected = [
        {
            "role": "tool",
            "name": "functions.empty_tool",
            "content": "",
            "channel": "commentary",
        },
    ]

    messages = parse_chat_input_to_harmony_message(
        chat_msg, tool_id_names={call_id: "empty_tool"}
    )

    verify_harmony_messages(messages, expected)
class TestAutoDropAnalysisMessages:
    """Expectations for ``auto_drop_analysis_messages``.

    As the cases below show, analysis-channel messages that are followed by a
    final-channel message are expected to be removed, while analysis messages
    with no later final message are expected to be kept.
    """

    @staticmethod
    def _assistant(text: str, channel: str) -> Message:
        """Build an assistant message with ``text`` on the given channel."""
        return Message.from_role_and_content(Role.ASSISTANT, text).with_channel(
            channel
        )

    def test_no_analysis_messages(self) -> None:
        """A lone final message passes through unchanged."""
        messages = [self._assistant("The answer is 4.", "final")]

        assert auto_drop_analysis_messages(messages) == messages

    def test_only_analysis_message(self) -> None:
        """A lone analysis message is kept when no final message exists."""
        messages = [
            self._assistant("I'm thinking about the user's question.", "analysis"),
        ]

        assert auto_drop_analysis_messages(messages) == messages

    def test_multiple_analysis_messages_without_final_message(self) -> None:
        """Several analysis messages are all kept without a final message."""
        thoughts = (
            "I'm thinking about the user's question.",
            "I'm thinking more.",
            "I'm thinking even more.",
        )
        messages = [self._assistant(text, "analysis") for text in thoughts]

        assert auto_drop_analysis_messages(messages) == messages

    def test_only_final_message(self) -> None:
        """A single final message is returned as-is."""
        messages = [self._assistant("The answer is 4.", "final")]

        assert auto_drop_analysis_messages(messages) == messages

    def test_drops_one_analysis_messages_before_final_message(self) -> None:
        """One analysis message ahead of the final message is removed."""
        messages = [
            self._assistant("I'm thinking about the user's question.", "analysis"),
            self._assistant("The answer is 4.", "final"),
            self._assistant("I should think harder.", "analysis"),
        ]

        result = auto_drop_analysis_messages(messages)

        # Should have dropped the first analysis message
        assert result == messages[1:]

    def test_drops_all_analysis_messages_before_final_message(self) -> None:
        """Every analysis message ahead of the final message is removed."""
        messages = [
            self._assistant("I'm thinking about the user's question.", "analysis"),
            self._assistant("I'm thinking more.", "analysis"),
            self._assistant("I'm thinking even more.", "analysis"),
            self._assistant("The answer is 4.", "final"),
            self._assistant("I should think harder.", "analysis"),
        ]

        result = auto_drop_analysis_messages(messages)

        # Should have dropped the first 3 analysis messages
        assert result == messages[3:]

    def test_multiple_analysis_messages_with_multiple_final_messages(self) -> None:
        """With two final messages, only the two final messages remain."""
        messages = [
            self._assistant("I'm thinking about the user's question.", "analysis"),
            self._assistant("I'm thinking more.", "analysis"),
            self._assistant("I'm thinking even more.", "analysis"),
            self._assistant("The answer is 4.", "final"),
            self._assistant("I should think harder.", "analysis"),
            self._assistant("The answer is 5.", "final"),
        ]

        result = auto_drop_analysis_messages(messages)

        # Should have dropped all those analysis messages
        assert len(result) == 2
        first, second = result[0], result[1]
        assert first.content[0].text == "The answer is 4."
        assert second.content[0].text == "The answer is 5."

    def test_drops_non_assistant_analysis_messages(self) -> None:
        """An analysis message authored by the tool role is removed as well."""
        tool_analysis = Message.from_role_and_content(
            Role.TOOL, "The tool thinks we should think harder."
        ).with_channel("analysis")
        messages = [
            tool_analysis,
            self._assistant("The answer is 4.", "final"),
        ]

        result = auto_drop_analysis_messages(messages)

        # Should have dropped the analysis message
        assert result == messages[1:]
class TestParseChatOutput:
    """Expectations for ``parse_chat_output`` on complete and cut-off output.

    Each case encodes a Harmony-formatted string to token ids and checks the
    reasoning and final-content values returned by ``parse_chat_output``.
    """

    @staticmethod
    def _parse(harmony_str: str):
        """Encode ``harmony_str`` and return ``(reasoning, final_content)``."""
        token_ids = get_encoding().encode(harmony_str, allowed_special="all")
        reasoning, final_content, _ = parse_chat_output(token_ids)
        return reasoning, final_content

    def test_parse_chat_output_interrupted_first_message(self) -> None:
        """A final message with no end token still yields its content."""
        reasoning, final_content = self._parse(
            "<|channel|>final<|message|>I'm in the middle of answering"
        )

        assert reasoning is None
        assert final_content == "I'm in the middle of answering"

    def test_parse_chat_output_interrupted_reasoning_first_message(self) -> None:
        """An analysis message with no end token still yields its reasoning."""
        reasoning, final_content = self._parse(
            "<|channel|>analysis<|message|>I'm in the middle of thinking"
        )

        assert reasoning == "I'm in the middle of thinking"
        assert final_content is None

    def test_parse_chat_output_complete_reasoning_interrupted_content(self) -> None:
        """Finished reasoning followed by a cut-off final message."""
        harmony_str = (
            "<|channel|>analysis<|message|>I'm thinking.<|end|>"
            "<|start|>assistant<|channel|>final"
            "<|message|>I'm in the middle of answering"
        )

        reasoning, final_content = self._parse(harmony_str)

        assert reasoning == "I'm thinking."
        assert final_content == "I'm in the middle of answering"

    def test_parse_chat_output_complete_content(self) -> None:
        """A finished final message with no reasoning."""
        reasoning, final_content = self._parse(
            "<|channel|>final<|message|>The answer is 4.<|end|>"
        )

        assert reasoning is None
        assert final_content == "The answer is 4."

    def test_parse_chat_output_complete_commentary(self) -> None:
        """A finished commentary message is reported as final content."""
        harmony_str = (
            "<|channel|>commentary<|message|>I need to call some tools.<|end|>"
        )

        reasoning, final_content = self._parse(harmony_str)

        assert reasoning is None
        assert final_content == "I need to call some tools."

    def test_parse_chat_output_complete_reasoning(self) -> None:
        """A finished analysis message with no final content."""
        harmony_str = (
            "<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
        )

        reasoning, final_content = self._parse(harmony_str)

        assert reasoning == "I've thought hard about this."
        assert final_content is None

    def test_parse_chat_output_complete_reasoning_and_content(self) -> None:
        """Finished reasoning followed by a finished final message."""
        harmony_str = (
            "<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
            "<|start|>assistant<|channel|>final<|message|>The answer is 4.<|end|>"
        )

        reasoning, final_content = self._parse(harmony_str)

        assert reasoning == "I've thought hard about this."
        assert final_content == "The answer is 4."
class
TestParseOutputMessage
:
class
TestParseOutputMessage
:
...
...
tests/entrypoints/openai/test_chat_error.py
0 → 100644
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
,
field
from
http
import
HTTPStatus
from
typing
import
Any
from
unittest.mock
import
AsyncMock
,
MagicMock
import
pytest
from
vllm.config.multimodal
import
MultiModalConfig
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
,
ErrorResponse
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_models
import
BaseModelPath
,
OpenAIServingModels
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.v1.engine.async_llm
import
AsyncLLM
# Model identifiers registered with the serving layer in these tests.
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"

# One BaseModelPath per identifier, with name and model_path set to the same value.
BASE_MODEL_PATHS = [
    BaseModelPath(name=model, model_path=model)
    for model in (MODEL_NAME, MODEL_NAME_SHORT)
]
@dataclass
class MockHFConfig:
    """Stub config object exposing only a ``model_type`` field."""

    # Placeholder value; defaults to "any".
    model_type: str = "any"
@dataclass
class MockModelConfig:
    """Stand-in model config assigned to the mocked engine in these tests.

    Only the annotated attributes are dataclass fields (and therefore
    constructor arguments); the unannotated ones are plain class attributes.
    """

    # Plain class attributes (not dataclass fields).
    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processor_pattern = None

    # Dataclass fields, mixed with a few more plain attributes.
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def get_diff_sampling_param(self):
        """Return ``diff_sampling_param``, or a new empty dict when it is falsy."""
        if self.diff_sampling_param:
            return self.diff_sampling_param
        return {}
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    """Create an ``OpenAIServingChat`` whose input processing is mocked out.

    ``_process_inputs`` and ``_preprocess_chat`` are replaced by ``AsyncMock``
    objects backed by the local fakes below, so the tests control what the
    serving layer sees as its preprocessed prompt.
    """

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        return dict(engine_prompt), {}

    async def _fake_preprocess_chat(*args, **kwargs):
        # return conversation, engine_prompts
        conversation = [{"role": "user", "content": "Test"}]
        engine_prompts = [{"prompt_token_ids": [1, 2, 3]}]
        return (conversation, engine_prompts)

    serving_models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    chat = OpenAIServingChat(
        engine,
        serving_models,
        response_role="assistant",
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )

    chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
    return chat
@pytest.mark.asyncio
async def test_chat_error_non_stream():
    """Non-streaming chat with an output whose finish_reason is "error".

    The call is expected to return an ``ErrorResponse`` of type
    ``InternalServerError`` with message "Internal server error" and code
    ``HTTPStatus.INTERNAL_SERVER_ERROR``.
    """
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(engine)

    # Single finished output carrying the error finish reason.
    failed_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[
            CompletionOutput(
                index=0,
                text="",
                token_ids=[],
                cumulative_logprob=None,
                logprobs=None,
                finish_reason="error",
            )
        ],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def fake_generate(*args, **kwargs):
        yield failed_output

    engine.generate = MagicMock(side_effect=fake_generate)

    chat_request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=False,
    )

    response = await serving_chat.create_chat_completion(chat_request)

    assert isinstance(response, ErrorResponse)
    error = response.error
    assert error.type == "InternalServerError"
    assert error.message == "Internal server error"
    assert error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_chat_error_stream():
    """Streaming chat where the second output has finish_reason "error".

    The collected stream is expected to contain at least two chunks, to
    mention "Internal server error" in some chunk, and to end with the
    ``[DONE]`` sentinel chunk.
    """
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()

    serving_chat = _build_serving_chat(engine)

    # First output: one token, generation still in progress.
    partial_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[
            CompletionOutput(
                index=0,
                text="Hello",
                token_ids=[100],
                cumulative_logprob=None,
                logprobs=None,
                finish_reason=None,
            )
        ],
        finished=False,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    # Second output: same text, finished with the error finish reason.
    failed_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[
            CompletionOutput(
                index=0,
                text="Hello",
                token_ids=[100],
                cumulative_logprob=None,
                logprobs=None,
                finish_reason="error",
            )
        ],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def fake_generate(*args, **kwargs):
        yield partial_output
        yield failed_output

    engine.generate = MagicMock(side_effect=fake_generate)

    chat_request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=True,
    )

    response = await serving_chat.create_chat_completion(chat_request)

    chunks = [chunk async for chunk in response]

    assert len(chunks) >= 2
    assert any("Internal server error" in chunk for chunk in chunks), (
        f"Expected error message in chunks: {chunks}"
    )
    assert chunks[-1] == "data: [DONE]\n\n"
tests/entrypoints/openai/test_completion_error.py
0 → 100644
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
,
field
from
http
import
HTTPStatus
from
typing
import
Any
from
unittest.mock
import
AsyncMock
,
MagicMock
import
pytest
from
vllm.config.multimodal
import
MultiModalConfig
from
vllm.entrypoints.openai.protocol
import
CompletionRequest
,
ErrorResponse
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_models
import
BaseModelPath
,
OpenAIServingModels
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.v1.engine.async_llm
import
AsyncLLM
# Model identifiers registered with the serving layer in these tests.
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"

# One BaseModelPath per identifier, with name and model_path set to the same value.
BASE_MODEL_PATHS = [
    BaseModelPath(name=model, model_path=model)
    for model in (MODEL_NAME, MODEL_NAME_SHORT)
]
@dataclass
class MockHFConfig:
    """Stub config object exposing only a ``model_type`` field."""

    # Placeholder value; defaults to "any".
    model_type: str = "any"
@dataclass
class MockModelConfig:
    """Stand-in model config assigned to the mocked engine in these tests.

    Only the annotated attributes are dataclass fields (and therefore
    constructor arguments); the unannotated ones are plain class attributes.
    """

    # Plain class attributes (not dataclass fields).
    task = "generate"
    runner_type = "generate"
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processor_pattern = None

    # Dataclass fields, mixed with a few more plain attributes.
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False

    def get_diff_sampling_param(self):
        """Return ``diff_sampling_param``, or a new empty dict when it is falsy."""
        if self.diff_sampling_param:
            return self.diff_sampling_param
        return {}
def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
    """Create an ``OpenAIServingCompletion`` with ``_process_inputs`` mocked.

    ``_process_inputs`` is replaced by an ``AsyncMock`` that returns the engine
    prompt as a dict together with an empty dict.
    """

    async def _fake_process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        *,
        lora_request,
        trace_headers,
        priority,
    ):
        return dict(engine_prompt), {}

    serving_models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    completion = OpenAIServingCompletion(
        engine,
        serving_models,
        request_logger=None,
    )

    completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
    return completion
@pytest.mark.asyncio
async def test_completion_error_non_stream():
    """Non-streaming completion with an output whose finish_reason is "error".

    The call is expected to return an ``ErrorResponse`` of type
    ``InternalServerError`` with message "Internal server error" and code
    ``HTTPStatus.INTERNAL_SERVER_ERROR``.
    """
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()

    serving_completion = _build_serving_completion(engine)

    # Single finished output carrying the error finish reason.
    failed_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[
            CompletionOutput(
                index=0,
                text="",
                token_ids=[],
                cumulative_logprob=None,
                logprobs=None,
                finish_reason="error",
            )
        ],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def fake_generate(*args, **kwargs):
        yield failed_output

    engine.generate = MagicMock(side_effect=fake_generate)

    completion_request = CompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
        max_tokens=10,
        stream=False,
    )

    response = await serving_completion.create_completion(completion_request)

    assert isinstance(response, ErrorResponse)
    error = response.error
    assert error.type == "InternalServerError"
    assert error.message == "Internal server error"
    assert error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_completion_error_stream():
    """Streaming completion where the second output has finish_reason "error".

    The collected stream is expected to contain at least two chunks, to
    mention "Internal server error" in some chunk, and to end with the
    ``[DONE]`` sentinel chunk.
    """
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()

    serving_completion = _build_serving_completion(engine)

    # First output: one token, generation still in progress.
    partial_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[
            CompletionOutput(
                index=0,
                text="Hello",
                token_ids=[100],
                cumulative_logprob=None,
                logprobs=None,
                finish_reason=None,
            )
        ],
        finished=False,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    # Second output: same text, finished with the error finish reason.
    failed_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[
            CompletionOutput(
                index=0,
                text="Hello",
                token_ids=[100],
                cumulative_logprob=None,
                logprobs=None,
                finish_reason="error",
            )
        ],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def fake_generate(*args, **kwargs):
        yield partial_output
        yield failed_output

    engine.generate = MagicMock(side_effect=fake_generate)

    completion_request = CompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
        max_tokens=10,
        stream=True,
    )

    response = await serving_completion.create_completion(completion_request)

    chunks = [chunk async for chunk in response]

    assert len(chunks) >= 2
    assert any("Internal server error" in chunk for chunk in chunks), (
        f"Expected error message in chunks: {chunks}"
    )
    assert chunks[-1] == "data: [DONE]\n\n"
tests/entrypoints/openai/test_messages.py
View file @
a3f8d5dd
...
@@ -79,9 +79,12 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic):
...
@@ -79,9 +79,12 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic):
assert
chunk_count
>
0
assert
chunk_count
>
0
assert
first_chunk
is
not
None
,
"message_start chunk was never observed"
assert
first_chunk
is
not
None
,
"message_start chunk was never observed"
assert
first_chunk
.
usage
is
not
None
,
"first chunk should include usage stats"
assert
first_chunk
.
message
is
not
None
,
"first chunk should include message"
assert
first_chunk
.
usage
[
"output_tokens"
]
==
0
assert
first_chunk
.
message
.
usage
is
not
None
,
(
assert
first_chunk
.
usage
[
"input_tokens"
]
>
5
"first chunk should include usage stats"
)
assert
first_chunk
.
message
.
usage
.
output_tokens
==
0
assert
first_chunk
.
message
.
usage
.
input_tokens
>
5
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
...
tests/entrypoints/openai/test_responses_error.py
0 → 100644
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
http
import
HTTPStatus
from
unittest.mock
import
MagicMock
import
pytest
from
vllm.entrypoints.openai.protocol
import
ErrorResponse
from
vllm.entrypoints.openai.serving_engine
import
GenerationError
,
OpenAIServing
@pytest.mark.asyncio
async def test_raise_if_error_raises_generation_error():
    """``_raise_if_error`` raises ``GenerationError`` only for "error".

    The raised error is expected to read "Internal server error" and carry
    ``HTTPStatus.INTERNAL_SERVER_ERROR`` as its status code.
    """
    # create a minimal OpenAIServing instance
    engine = MagicMock()
    engine.model_config = MagicMock()
    engine.model_config.max_model_len = 100
    models = MagicMock()

    serving = OpenAIServing(
        engine_client=engine,
        models=models,
        request_logger=None,
    )

    # test that error finish_reason raises GenerationError
    with pytest.raises(GenerationError) as exc_info:
        serving._raise_if_error("error", "test-request-id")

    raised = exc_info.value
    assert str(raised) == "Internal server error"
    assert raised.status_code == HTTPStatus.INTERNAL_SERVER_ERROR

    # test that other finish_reasons don't raise
    for finish_reason in ("stop", "length", None):
        serving._raise_if_error(finish_reason, "test-request-id")  # should not raise
@pytest.mark.asyncio
async def test_convert_generation_error_to_response():
    """``_convert_generation_error_to_response`` builds an ``ErrorResponse``.

    The response is expected to have type ``InternalServerError``, message
    "Internal server error" and code ``HTTPStatus.INTERNAL_SERVER_ERROR``.
    """
    engine = MagicMock()
    engine.model_config = MagicMock()
    engine.model_config.max_model_len = 100
    models = MagicMock()

    serving = OpenAIServing(
        engine_client=engine,
        models=models,
        request_logger=None,
    )

    # create a GenerationError and convert it to an ErrorResponse
    error_response = serving._convert_generation_error_to_response(
        GenerationError("Internal server error")
    )

    assert isinstance(error_response, ErrorResponse)
    error = error_response.error
    assert error.type == "InternalServerError"
    assert error.message == "Internal server error"
    assert error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_convert_generation_error_to_streaming_response():
    """``_convert_generation_error_to_streaming_response`` returns a string.

    The string is expected to contain both the error message and the
    ``InternalServerError`` type name.
    """
    engine = MagicMock()
    engine.model_config = MagicMock()
    engine.model_config.max_model_len = 100
    models = MagicMock()

    serving = OpenAIServing(
        engine_client=engine,
        models=models,
        request_logger=None,
    )

    # create a GenerationError and convert it to a streaming error response
    error_json = serving._convert_generation_error_to_streaming_response(
        GenerationError("Internal server error")
    )

    assert isinstance(error_json, str)
    for expected_fragment in ("Internal server error", "InternalServerError"):
        assert expected_fragment in error_json
tests/entrypoints/openai/test_serving_chat.py
View file @
a3f8d5dd
...
@@ -11,13 +11,25 @@ import pytest_asyncio
...
@@ -11,13 +11,25 @@ import pytest_asyncio
from
openai
import
OpenAI
from
openai
import
OpenAI
from
vllm.config.multimodal
import
MultiModalConfig
from
vllm.config.multimodal
import
MultiModalConfig
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.entrypoints.openai.parser.harmony_utils
import
get_encoding
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionResponse
,
RequestResponseMetadata
,
)
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_models
import
BaseModelPath
,
OpenAIServingModels
from
vllm.entrypoints.openai.serving_models
import
BaseModelPath
,
OpenAIServingModels
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.tokenizers
import
get_tokenizer
from
vllm.tokenizers
import
get_tokenizer
from
vllm.tool_parsers
import
ToolParserManager
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
...utils
import
RemoteOpenAIServer
from
...utils
import
RemoteOpenAIServer
from
.utils
import
(
accumulate_streaming_response
,
verify_chat_response
,
verify_harmony_messages
,
)
GPT_OSS_MODEL_NAME
=
"openai/gpt-oss-20b"
GPT_OSS_MODEL_NAME
=
"openai/gpt-oss-20b"
...
@@ -728,3 +740,635 @@ async def test_serving_chat_data_parallel_rank_extraction():
...
@@ -728,3 +740,635 @@ async def test_serving_chat_data_parallel_rank_extraction():
# Verify that data_parallel_rank defaults to None
# Verify that data_parallel_rank defaults to None
assert
"data_parallel_rank"
in
mock_engine
.
generate
.
call_args
.
kwargs
assert
"data_parallel_rank"
in
mock_engine
.
generate
.
call_args
.
kwargs
assert
mock_engine
.
generate
.
call_args
.
kwargs
[
"data_parallel_rank"
]
is
None
assert
mock_engine
.
generate
.
call_args
.
kwargs
[
"data_parallel_rank"
]
is
None
class
TestServingChatWithHarmony
:
"""
These tests ensure Chat Completion requests are being properly converted into
Harmony messages and Harmony response messages back into Chat Completion responses.
These tests are not exhaustive, but each one was created to cover a specific case
that we got wrong but is now fixed.
Any changes to the tests and their expectations may result in changes to the
accuracy of model prompting and responses generated. It is suggested to run
an evaluation or benchmarking suite (such as bfcl multi_turn) to understand
any impact of changes in how we prompt Harmony models.
"""
@pytest.fixture(params=[False, True], ids=["non_streaming", "streaming"])
def stream(self, request) -> bool:
    """Yield ``False`` then ``True`` so each test runs in both response modes."""
    streaming_enabled = request.param
    return streaming_enabled
@pytest.fixture()
def mock_engine(self) -> AsyncLLM:
    """Provide a ``MagicMock`` specced as ``AsyncLLM`` with test defaults set."""
    engine = MagicMock(spec=AsyncLLM)
    engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    engine.errored = False
    engine.model_config = MockModelConfig()
    # The processors are plain mocks.
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()
    return engine
@pytest.fixture()
def serving_chat(self, mock_engine) -> OpenAIServingChat:
    """Provide a serving chat with Harmony enabled and the "openai" tool parser."""
    harmony_chat = _build_serving_chat(mock_engine)
    harmony_chat.use_harmony = True
    harmony_chat.tool_parser = ToolParserManager.get_tool_parser("openai")
    return harmony_chat
def mock_request_output_from_req_and_token_ids(
    self, req: ChatCompletionRequest, token_ids: list[int], finished: bool = False
) -> RequestOutput:
    """Wrap ``token_ids`` in a ``RequestOutput`` tied to ``req.request_id``.

    Args:
        req: Request whose ``request_id`` is copied onto the output.
        token_ids: Token ids placed on the single completion output.
        finished: Value for the output's ``finished`` flag.
    """
    # Our tests don't use most fields, so just get the token ids correct
    outputs = [
        CompletionOutput(
            index=0,
            text="",
            token_ids=token_ids,
            cumulative_logprob=0.0,
            logprobs=None,
        )
    ]
    return RequestOutput(
        request_id=req.request_id,
        prompt=[],
        prompt_token_ids=[],
        prompt_logprobs=None,
        outputs=outputs,
        finished=finished,
    )
@pytest.fixture
def weather_tools(self) -> list[dict[str, Any]]:
    """Provide a tool list holding a single ``get_weather`` function tool."""
    parameters = {
        "type": "object",
        "properties": {
            "location": {"type": "string"},
        },
        "required": ["location"],
    }
    get_weather = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the weather in a given location",
            "parameters": parameters,
        },
    }
    return [get_weather]
@pytest.fixture
def weather_messages_start(self) -> list[dict[str, Any]]:
    """Provide the opening conversation: one user question about Paris weather."""
    opening_question = {
        "role": "user",
        "content": "What's the weather like in Paris today?",
    }
    return [opening_question]
async def generate_response_from_harmony_str(
    self,
    serving_chat: OpenAIServingChat,
    req: ChatCompletionRequest,
    harmony_str: str,
    stream: bool = False,
) -> ChatCompletionResponse:
    """Feed a Harmony-formatted string through the chat response generators.

    ``harmony_str`` is encoded to token ids and replayed as mocked engine
    output. With ``stream`` set, one output per token is emitted followed by
    an empty finished output, and the streamed result is accumulated;
    otherwise a single finished output carrying all token ids is emitted.

    Args:
        serving_chat: Serving object whose generator methods are exercised.
        req: Request supplying the request id and model name.
        harmony_str: Model output to replay, in Harmony text form.
        stream: Selects the streaming generator when true.
    """
    harmony_token_ids = get_encoding().encode(harmony_str, allowed_special="all")
    make_output = self.mock_request_output_from_req_and_token_ids

    async def result_generator():
        if not stream:
            yield make_output(req, harmony_token_ids, finished=True)
            return
        for token_id in harmony_token_ids:
            yield make_output(req, [token_id])
        yield make_output(req, [], finished=True)

    if stream:
        generator_func = serving_chat.chat_completion_stream_generator
    else:
        generator_func = serving_chat.chat_completion_full_generator

    result = generator_func(
        request=req,
        result_generator=result_generator(),
        request_id=req.request_id,
        model_name=req.model,
        conversation=[],
        tokenizer=get_tokenizer(req.model),
        request_metadata=RequestResponseMetadata(
            request_id=req.request_id,
            model_name=req.model,
        ),
    )

    if not stream:
        return await result
    return await accumulate_streaming_response(result)
@pytest.mark.asyncio
async def test_simple_chat(self, serving_chat, stream):
    """Two-turn plain chat: check the Harmony inputs and the chat response.

    The first turn's reasoning and final answer are fed back as input to a
    second request, whose Harmony messages are expected to keep only the
    final-channel assistant message.
    """
    messages = [{"role": "user", "content": "what is 1+1?"}]

    # Test the Harmony messages for the first turn's input
    first_req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
    first_inputs, _ = serving_chat._make_request_with_harmony(first_req)
    verify_harmony_messages(
        first_inputs,
        [
            {"role": "system"},
            {"role": "developer"},
            {"role": "user", "content": messages[0]["content"]},
        ],
    )

    # Test the Chat Completion response for the first turn's output
    reasoning_str = "We need to think really hard about this."
    final_str = "The answer is 2."
    analysis_part = f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
    final_part = f"<|start|>assistant<|channel|>final<|message|>{final_str}<|end|>"
    response = await self.generate_response_from_harmony_str(
        serving_chat, first_req, analysis_part + final_part, stream=stream
    )
    verify_chat_response(response, content=final_str, reasoning=reasoning_str)

    # Add the output messages from the first turn as input to the second turn
    messages.extend(
        choice.message.model_dump(exclude_none=True) for choice in response.choices
    )

    # Test the Harmony messages for the second turn's input
    second_req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
    second_inputs, _ = serving_chat._make_request_with_harmony(second_req)
    verify_harmony_messages(
        second_inputs,
        [
            {"role": "system"},
            {"role": "developer"},
            {"role": "user"},
            # The analysis message should be dropped on subsequent inputs because
            # of the subsequent assistant message to the final channel.
            {"role": "assistant", "channel": "final", "content": final_str},
        ],
    )
@pytest.mark.asyncio
async def test_tool_call_response_with_content(
    self, serving_chat, stream, weather_tools, weather_messages_start
):
    """Exercise a tool call that is preceded by commentary content.

    Turn one checks the Harmony input (including the tool definition) and a
    response carrying both commentary text and a ``get_weather`` call. Turn
    two appends the assistant output plus a tool result and checks the
    Harmony input built from that conversation.
    """
    # Copy so the fixture's list is not mutated by the appends below.
    messages = list(weather_messages_start)

    # Turn 1 input: Harmony messages built from the initial request.
    first_request = ChatCompletionRequest(
        model=MODEL_NAME, messages=messages, tools=weather_tools
    )
    first_harmony_input, _ = serving_chat._make_request_with_harmony(first_request)
    expected_first_input = [
        {"role": "system"},
        {"role": "developer", "tool_definitions": ["get_weather"]},
        {"role": "user", "content": messages[0]["content"]},
    ]
    verify_harmony_messages(first_harmony_input, expected_first_input)

    # Turn 1 output: commentary text followed by a function call.
    commentary_str = "We'll call get_weather."
    tool_args_str = '{"location": "Paris"}'
    response_str = (
        f"<|channel|>commentary<|message|>{commentary_str}<|end|>"
        "<|start|>assistant to=functions.get_weather<|channel|>commentary"
        f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
    )
    response = await self.generate_response_from_harmony_str(
        serving_chat, first_request, response_str, stream=stream
    )
    expected_calls = [("get_weather", tool_args_str)]
    verify_chat_response(
        response, content=commentary_str, tool_calls=expected_calls
    )

    first_choice_message = response.choices[0].message
    tool_call = first_choice_message.tool_calls[0]

    # Carry every choice of the first turn over as input to the second turn.
    messages.extend(
        choice.message.model_dump(exclude_none=True) for choice in response.choices
    )

    # Answer the tool call with a tool output message.
    tool_output = {
        "role": "tool",
        "tool_call_id": tool_call.id,
        "content": "20 degrees Celsius",
    }
    messages.append(tool_output)

    # Turn 2 input: Harmony messages built from the extended conversation.
    second_request = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
    second_harmony_input, _ = serving_chat._make_request_with_harmony(second_request)
    expected_second_input = [
        {"role": "system"},
        {"role": "developer"},
        {"role": "user"},
        {
            "role": "assistant",
            "channel": "commentary",
            "content": commentary_str,
        },
        {
            "role": "assistant",
            "channel": "commentary",
            "recipient": "functions.get_weather",
            "content": tool_args_str,
        },
        {
            "role": "tool",
            "author_name": "functions.get_weather",
            "channel": "commentary",
            "recipient": "assistant",
            "content": "20 degrees Celsius",
        },
    ]
    verify_harmony_messages(second_harmony_input, expected_second_input)
@pytest.mark.asyncio
async def test_tools_and_reasoning(
    self, serving_chat, stream, weather_tools, weather_messages_start
):
    """Exercise a tool call that is preceded by reasoning.

    Turn one checks the Harmony input (including the tool definition) and a
    response carrying an analysis message plus a ``get_weather`` call. Turn
    two appends the assistant output and a tool result, then checks that the
    analysis message is still present in the Harmony input.
    """
    # Copy so the fixture's list is not mutated by the appends below.
    messages = list(weather_messages_start)

    # Turn 1 input: Harmony messages built from the initial request.
    first_request = ChatCompletionRequest(
        model=MODEL_NAME, messages=messages, tools=weather_tools
    )
    first_harmony_input, _ = serving_chat._make_request_with_harmony(first_request)
    expected_first_input = [
        {"role": "system"},
        {"role": "developer", "tool_definitions": ["get_weather"]},
        {"role": "user", "content": messages[0]["content"]},
    ]
    verify_harmony_messages(first_harmony_input, expected_first_input)

    # Turn 1 output: reasoning followed by a function call.
    reasoning_str = "I'll call get_weather."
    tool_args_str = '{"location": "Paris"}'
    response_str = (
        f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
        "<|start|>assistant to=functions.get_weather<|channel|>commentary"
        f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
    )
    response = await self.generate_response_from_harmony_str(
        serving_chat, first_request, response_str, stream=stream
    )
    expected_calls = [("get_weather", tool_args_str)]
    verify_chat_response(
        response, reasoning=reasoning_str, tool_calls=expected_calls
    )

    first_choice_message = response.choices[0].message
    tool_call = first_choice_message.tool_calls[0]

    # Carry every choice of the first turn over as input to the second turn.
    messages.extend(
        choice.message.model_dump(exclude_none=True) for choice in response.choices
    )

    # Answer the tool call with a tool output message.
    tool_output = {
        "role": "tool",
        "tool_call_id": tool_call.id,
        "content": "20 degrees Celsius",
    }
    messages.append(tool_output)

    # Turn 2 input: Harmony messages built from the extended conversation.
    second_request = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
    second_harmony_input, _ = serving_chat._make_request_with_harmony(second_request)
    expected_second_input = [
        {"role": "system"},
        {"role": "developer"},
        {"role": "user"},
        {
            "role": "assistant",
            "channel": "analysis",
            "content": reasoning_str,
        },
        {
            "role": "assistant",
            "channel": "commentary",
            "recipient": "functions.get_weather",
            "content": tool_args_str,
        },
        {
            "role": "tool",
            "author_name": "functions.get_weather",
            "channel": "commentary",
            "recipient": "assistant",
            "content": "20 degrees Celsius",
        },
    ]
    verify_harmony_messages(second_harmony_input, expected_second_input)
@pytest.mark.asyncio
async def test_multi_turn_tools_and_reasoning(
    self, serving_chat, stream, weather_tools, weather_messages_start
):
    """Exercise four turns mixing reasoning, tool calls and a final answer.

    Turn 1: reasoning + ``get_weather`` call for Paris.
    Turn 2: tool result is fed back and a final answer is produced.
    Turn 3: a new user question leads to reasoning + a call for Boston.
    Turn 4: the Boston tool result is fed back and the Harmony input is
    checked for the whole conversation.
    """

    def assistant_tool_call(arguments):
        # Expected Harmony entry for an assistant call to get_weather.
        return {
            "role": "assistant",
            "channel": "commentary",
            "recipient": "functions.get_weather",
            "content": arguments,
        }

    def tool_result(text):
        # Expected Harmony entry for the get_weather tool's answer.
        return {
            "role": "tool",
            "author_name": "functions.get_weather",
            "channel": "commentary",
            "recipient": "assistant",
            "content": text,
        }

    def dump_choices(chat_response):
        # Serialize every choice's message the way a client would resend it.
        return [
            choice.message.model_dump(exclude_none=True)
            for choice in chat_response.choices
        ]

    # Copy so the fixture's list is not mutated by the appends below.
    messages = list(weather_messages_start)

    # ---- Turn 1 input ----
    req = ChatCompletionRequest(
        model=MODEL_NAME, messages=messages, tools=weather_tools
    )
    input_messages, _ = serving_chat._make_request_with_harmony(req)
    expected_turn_1 = [
        {"role": "system"},
        {"role": "developer", "tool_definitions": ["get_weather"]},
        {"role": "user", "content": messages[0]["content"]},
    ]
    verify_harmony_messages(input_messages, expected_turn_1)

    # ---- Turn 1 output: reasoning + Paris tool call ----
    reasoning_str = "I'll call get_weather."
    paris_tool_args_str = '{"location": "Paris"}'
    response_str = (
        f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
        "<|start|>assistant to=functions.get_weather<|channel|>commentary"
        f"<|constrain|>json<|message|>{paris_tool_args_str}<|call|>"
    )
    response = await self.generate_response_from_harmony_str(
        serving_chat, req, response_str, stream=stream
    )
    verify_chat_response(
        response,
        reasoning=reasoning_str,
        tool_calls=[("get_weather", paris_tool_args_str)],
    )
    tool_call = response.choices[0].message.tool_calls[0]

    # Feed the first turn's output and the tool's answer into turn 2.
    messages.extend(dump_choices(response))
    messages.append(
        {
            "role": "tool",
            "tool_call_id": tool_call.id,
            "content": "20 degrees Celsius",
        }
    )

    # ---- Turn 2 input ----
    req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
    input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
    expected_turn_2 = [
        {"role": "system"},
        {"role": "developer"},
        {"role": "user"},
        {
            "role": "assistant",
            "channel": "analysis",
            "content": reasoning_str,
        },
        assistant_tool_call(paris_tool_args_str),
        tool_result("20 degrees Celsius"),
    ]
    verify_harmony_messages(input_messages_2, expected_turn_2)

    # ---- Turn 2 output: final answer only ----
    paris_weather_str = "The weather in Paris today is 20 degrees Celsius."
    response_str = f"<|channel|>final<|message|>{paris_weather_str}<|end|>"
    response_2 = await self.generate_response_from_harmony_str(
        serving_chat, req_2, response_str, stream=stream
    )
    verify_chat_response(response_2, content=paris_weather_str)

    # Feed the second turn's output and a new user question into turn 3.
    messages.extend(dump_choices(response_2))
    messages.append(
        {
            "role": "user",
            "content": "What's the weather like in Boston today?",
        }
    )

    # ---- Turn 3 input ----
    req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
    input_messages_3, _ = serving_chat._make_request_with_harmony(req_3)
    # No analysis entry is expected before the Paris tool call here.
    expected_turn_3 = [
        {"role": "system"},
        {"role": "developer"},
        {"role": "user"},
        assistant_tool_call(paris_tool_args_str),
        tool_result("20 degrees Celsius"),
        {
            "role": "assistant",
            "channel": "final",
            "content": paris_weather_str,
        },
        {"role": "user", "content": messages[-1]["content"]},
    ]
    verify_harmony_messages(input_messages_3, expected_turn_3)

    # ---- Turn 3 output: reasoning + Boston tool call ----
    reasoning_str = "I'll call get_weather."
    boston_tool_args_str = '{"location": "Boston"}'
    response_str = (
        f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
        "<|start|>assistant to=functions.get_weather<|channel|>commentary"
        f"<|constrain|>json<|message|>{boston_tool_args_str}<|call|>"
    )
    # NOTE(review): this passes the first-turn `req` rather than `req_3`.
    # It may be a copy-paste slip, or intentional because `req` is the only
    # request built with `tools` -- confirm before changing.
    response_3 = await self.generate_response_from_harmony_str(
        serving_chat, req, response_str, stream=stream
    )
    verify_chat_response(
        response_3,
        reasoning=reasoning_str,
        tool_calls=[("get_weather", boston_tool_args_str)],
    )
    tool_call = response_3.choices[0].message.tool_calls[0]

    # Feed the third turn's output and the tool's answer into turn 4.
    messages.extend(dump_choices(response_3))
    messages.append(
        {
            "role": "tool",
            "tool_call_id": tool_call.id,
            "content": "10 degrees Celsius",
        }
    )

    # ---- Turn 4 input ----
    req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
    input_messages_4, _ = serving_chat._make_request_with_harmony(req_4)
    expected_turn_4 = [
        {"role": "system"},
        {"role": "developer"},
        {"role": "user"},
        {"role": "assistant"},
        {"role": "tool"},
        {
            "role": "assistant",
            "channel": "final",
        },
        {"role": "user"},
        {
            "role": "assistant",
            "channel": "analysis",
            "content": reasoning_str,
        },
        assistant_tool_call(boston_tool_args_str),
        tool_result("10 degrees Celsius"),
    ]
    verify_harmony_messages(input_messages_4, expected_turn_4)
@pytest.mark.asyncio
async def test_non_tool_reasoning(self, serving_chat):
    """Check the Harmony input for an assistant message with reasoning and content.

    The assistant message carries both ``reasoning`` and non-empty
    ``content``; only a final-channel entry is expected for it.
    """
    user_turn: dict[str, Any] = {
        "role": "user",
        "content": "What's 2+2?",
    }
    assistant_turn: dict[str, Any] = {
        "role": "assistant",
        "reasoning": "Adding 2 and 2 is easy. The result is 4.",
        "content": "4",
    }
    messages: list[dict[str, Any]] = [user_turn, assistant_turn]

    request = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
    harmony_input, _ = serving_chat._make_request_with_harmony(request)

    expected = [
        {"role": "system"},
        {"role": "developer"},
        {"role": "user", "content": messages[0]["content"]},
        # No analysis entry is expected for the reasoning: it is dropped
        # because of the later assistant message to the final channel.
        {
            "role": "assistant",
            "channel": "final",
            "content": messages[1]["content"],
        },
    ]
    verify_harmony_messages(harmony_input, expected)
@pytest.mark.asyncio
async def test_non_tool_reasoning_empty_content(self, serving_chat):
    """Check the Harmony input for reasoning paired with empty-string content.

    With ``content`` set to ``""``, the assistant message is expected to
    appear as an analysis-channel entry holding the reasoning text.
    """
    user_turn: dict[str, Any] = {
        "role": "user",
        "content": "What's 2+2?",
    }
    assistant_turn: dict[str, Any] = {
        "role": "assistant",
        "reasoning": "Adding 2 and 2 is easy. The result is 4.",
        "content": "",
    }
    messages: list[dict[str, Any]] = [user_turn, assistant_turn]

    request = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
    harmony_input, _ = serving_chat._make_request_with_harmony(request)

    expected = [
        {"role": "system"},
        {"role": "developer"},
        {"role": "user", "content": messages[0]["content"]},
        {
            "role": "assistant",
            "channel": "analysis",
            "content": messages[1]["reasoning"],
        },
    ]
    verify_harmony_messages(harmony_input, expected)
@pytest.mark.asyncio
async def test_non_tool_reasoning_empty_content_list(self, serving_chat):
    """Check the Harmony input for reasoning paired with an empty content list.

    With ``content`` set to ``[]``, the assistant message is expected to
    appear as an analysis-channel entry holding the reasoning text.
    """
    user_turn: dict[str, Any] = {
        "role": "user",
        "content": "What's 2+2?",
    }
    assistant_turn: dict[str, Any] = {
        "role": "assistant",
        "reasoning": "Adding 2 and 2 is easy. The result is 4.",
        "content": [],
    }
    messages: list[dict[str, Any]] = [user_turn, assistant_turn]

    request = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
    harmony_input, _ = serving_chat._make_request_with_harmony(request)

    expected = [
        {"role": "system"},
        {"role": "developer"},
        {"role": "user", "content": messages[0]["content"]},
        {
            "role": "assistant",
            "channel": "analysis",
            "content": messages[1]["reasoning"],
        },
    ]
    verify_harmony_messages(harmony_input, expected)
tests/entrypoints/openai/test_serving_engine.py
View file @
a3f8d5dd
...
@@ -10,7 +10,7 @@ import pytest
...
@@ -10,7 +10,7 @@ import pytest
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.tokenizers
import
MistralTokenizer
from
vllm.tokenizers
.mistral
import
MistralTokenizer
@
pytest
.
fixture
()
@
pytest
.
fixture
()
...
...
tests/entrypoints/openai/test_serving_responses.py
View file @
a3f8d5dd
...
@@ -21,7 +21,7 @@ from vllm.entrypoints.openai.serving_responses import (
...
@@ -21,7 +21,7 @@ from vllm.entrypoints.openai.serving_responses import (
extract_tool_types
,
extract_tool_types
,
)
)
from
vllm.entrypoints.tool_server
import
ToolServer
from
vllm.entrypoints.tool_server
import
ToolServer
from
vllm.inputs.data
import
TokensPrompt
as
EngineTokensPrompt
from
vllm.inputs.data
import
TokensPrompt
class
MockConversationContext
(
ConversationContext
):
class
MockConversationContext
(
ConversationContext
):
...
@@ -237,7 +237,7 @@ class TestValidateGeneratorInput:
...
@@ -237,7 +237,7 @@ class TestValidateGeneratorInput:
"""Test _validate_generator_input with valid prompt length"""
"""Test _validate_generator_input with valid prompt length"""
# Create an engine prompt with valid length (less than max_model_len)
# Create an engine prompt with valid length (less than max_model_len)
valid_prompt_token_ids
=
list
(
range
(
5
))
# 5 tokens < 100 max_model_len
valid_prompt_token_ids
=
list
(
range
(
5
))
# 5 tokens < 100 max_model_len
engine_prompt
=
Engine
TokensPrompt
(
prompt_token_ids
=
valid_prompt_token_ids
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
valid_prompt_token_ids
)
# Call the method
# Call the method
result
=
serving_responses_instance
.
_validate_generator_input
(
engine_prompt
)
result
=
serving_responses_instance
.
_validate_generator_input
(
engine_prompt
)
...
@@ -247,7 +247,7 @@ class TestValidateGeneratorInput:
...
@@ -247,7 +247,7 @@ class TestValidateGeneratorInput:
# create an invalid engine prompt
# create an invalid engine prompt
invalid_prompt_token_ids
=
list
(
range
(
200
))
# 200 tokens >= 100 max_model_len
invalid_prompt_token_ids
=
list
(
range
(
200
))
# 200 tokens >= 100 max_model_len
engine_prompt
=
Engine
TokensPrompt
(
prompt_token_ids
=
invalid_prompt_token_ids
)
engine_prompt
=
TokensPrompt
(
prompt_token_ids
=
invalid_prompt_token_ids
)
# Call the method
# Call the method
result
=
serving_responses_instance
.
_validate_generator_input
(
engine_prompt
)
result
=
serving_responses_instance
.
_validate_generator_input
(
engine_prompt
)
...
...
tests/entrypoints/openai/test_sparse_tensor_validation.py
0 → 100644
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Sparse tensor validation in embedding APIs.
Tests verify that malicious sparse tensors are rejected before they can trigger
out-of-bounds memory writes during to_dense() operations.
"""
import
base64
import
io
import
pytest
import
torch
from
vllm.entrypoints.renderer
import
CompletionRenderer
from
vllm.multimodal.audio
import
AudioEmbeddingMediaIO
from
vllm.multimodal.image
import
ImageEmbeddingMediaIO
def
_encode_tensor
(
tensor
:
torch
.
Tensor
)
->
bytes
:
"""Helper to encode a tensor as base64 bytes."""
buffer
=
io
.
BytesIO
()
torch
.
save
(
tensor
,
buffer
)
buffer
.
seek
(
0
)
return
base64
.
b64encode
(
buffer
.
read
())
def
_create_malicious_sparse_tensor
()
->
torch
.
Tensor
:
"""
Create a malicious sparse COO tensor with out-of-bounds indices.
This tensor has indices that point beyond the declared shape, which would
cause an out-of-bounds write when converted to dense format without
validation.
"""
# Create a 3x3 sparse tensor but with indices pointing to (10, 10)
indices
=
torch
.
tensor
([[
10
],
[
10
]])
# Out of bounds for 3x3 shape
values
=
torch
.
tensor
([
1.0
])
shape
=
(
3
,
3
)
# Create sparse tensor (this will be invalid)
sparse_tensor
=
torch
.
sparse_coo_tensor
(
indices
,
values
,
shape
,
dtype
=
torch
.
float32
)
return
sparse_tensor
def
_create_valid_sparse_tensor
()
->
torch
.
Tensor
:
"""Create a valid sparse COO tensor for baseline testing."""
indices
=
torch
.
tensor
([[
0
,
1
,
2
],
[
0
,
1
,
2
]])
values
=
torch
.
tensor
([
1.0
,
2.0
,
3.0
])
shape
=
(
3
,
3
)
sparse_tensor
=
torch
.
sparse_coo_tensor
(
indices
,
values
,
shape
,
dtype
=
torch
.
float32
)
return
sparse_tensor
def
_create_valid_dense_tensor
()
->
torch
.
Tensor
:
"""Create a valid dense tensor for baseline testing."""
return
torch
.
randn
(
10
,
768
,
dtype
=
torch
.
float32
)
# (seq_len, hidden_size)
class TestPromptEmbedsValidation:
    """Sparse tensor validation for prompt embeddings (Completions API)."""

    def test_valid_dense_tensor_accepted(self, model_config):
        """Baseline: a well-formed dense tensor loads without error."""
        dense = _create_valid_dense_tensor()
        payload = _encode_tensor(dense)

        # No exception is expected here.
        loaded = CompletionRenderer(model_config).load_prompt_embeds(payload)

        assert len(loaded) == 1
        assert loaded[0]["prompt_embeds"].shape == dense.shape

    def test_valid_sparse_tensor_accepted(self):
        """Baseline: a well-formed sparse tensor loads without error."""
        # NOTE(review): this goes through ImageEmbeddingMediaIO rather than
        # the CompletionRenderer used by the other tests in this class --
        # confirm that is intended.
        sparse = _create_valid_sparse_tensor()
        payload_text = _encode_tensor(sparse).decode("utf-8")

        # No exception is expected here.
        loaded = ImageEmbeddingMediaIO().load_base64("", payload_text)

        assert loaded.shape == sparse.shape

    def test_malicious_sparse_tensor_rejected(self, model_config):
        """Security: a sparse tensor with out-of-bounds indices is rejected."""
        payload = _encode_tensor(_create_malicious_sparse_tensor())
        renderer = CompletionRenderer(model_config)

        # Either RuntimeError or ValueError counts as a rejection.
        with pytest.raises((RuntimeError, ValueError)) as exc_info:
            renderer.load_prompt_embeds(payload)

        # The message should point at the sparse tensor validation failure.
        error_msg = str(exc_info.value).lower()
        assert any(hint in error_msg for hint in ("sparse", "index", "bounds"))

    def test_extremely_large_indices_rejected(self, model_config):
        """Security: indices far beyond the declared shape are rejected."""
        renderer = CompletionRenderer(model_config)

        # A single entry at (999999, 999999) inside a declared 10x10 tensor.
        huge = 999999
        bad_tensor = torch.sparse_coo_tensor(
            torch.tensor([[huge], [huge]]),
            torch.tensor([1.0]),
            (10, 10),
            dtype=torch.float32,
        )
        payload = _encode_tensor(bad_tensor)

        with pytest.raises((RuntimeError, ValueError)):
            renderer.load_prompt_embeds(payload)

    def test_negative_indices_rejected(self, model_config):
        """Security: negative indices are rejected."""
        renderer = CompletionRenderer(model_config)

        # A single entry at (-1, -1) inside a declared 10x10 tensor.
        bad_tensor = torch.sparse_coo_tensor(
            torch.tensor([[-1], [-1]]),
            torch.tensor([1.0]),
            (10, 10),
            dtype=torch.float32,
        )
        payload = _encode_tensor(bad_tensor)

        with pytest.raises((RuntimeError, ValueError)):
            renderer.load_prompt_embeds(payload)
class TestImageEmbedsValidation:
    """Test sparse tensor validation in image embeddings (Chat API).

    Every test in this class drives ``ImageEmbeddingMediaIO`` so that the
    image embedding path is the one being exercised.
    """

    def test_valid_dense_tensor_accepted(self):
        """Baseline: Valid dense tensors should work normally."""
        io_handler = ImageEmbeddingMediaIO()

        valid_tensor = _create_valid_dense_tensor()
        encoded = _encode_tensor(valid_tensor)

        # Should not raise any exception
        result = io_handler.load_base64("", encoded.decode("utf-8"))
        assert result.shape == valid_tensor.shape

    def test_valid_sparse_tensor_accepted(self):
        """Baseline: Valid sparse tensors should load successfully."""
        # Use the image handler here: this class covers the image path, and
        # the audio handler has its own test class.
        io_handler = ImageEmbeddingMediaIO()

        valid_sparse = _create_valid_sparse_tensor()
        encoded = _encode_tensor(valid_sparse)

        # Should not raise any exception; only the shape is checked, so the
        # assertion holds whether or not the result is still sparse.
        result = io_handler.load_base64("", encoded.decode("utf-8"))
        assert result.shape == valid_sparse.shape

    def test_malicious_sparse_tensor_rejected(self):
        """Security: Malicious sparse tensors should be rejected."""
        io_handler = ImageEmbeddingMediaIO()

        malicious_tensor = _create_malicious_sparse_tensor()
        encoded = _encode_tensor(malicious_tensor)

        # Should raise RuntimeError or ValueError due to invalid sparse tensor
        with pytest.raises((RuntimeError, ValueError)) as exc_info:
            io_handler.load_base64("", encoded.decode("utf-8"))

        # Error should indicate sparse tensor validation failure
        error_msg = str(exc_info.value).lower()
        assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg

    def test_load_bytes_validates(self):
        """Security: Validation should also work for load_bytes method."""
        io_handler = ImageEmbeddingMediaIO()

        malicious_tensor = _create_malicious_sparse_tensor()
        # Serialize without base64 so the raw-bytes entry point is exercised.
        buffer = io.BytesIO()
        torch.save(malicious_tensor, buffer)
        buffer.seek(0)

        with pytest.raises((RuntimeError, ValueError)):
            io_handler.load_bytes(buffer.read())
class TestAudioEmbedsValidation:
    """Sparse tensor validation for audio embeddings (Chat API)."""

    def test_valid_dense_tensor_accepted(self):
        """Baseline: a well-formed dense tensor loads without error."""
        dense = _create_valid_dense_tensor()
        payload_text = _encode_tensor(dense).decode("utf-8")

        # No exception is expected here.
        loaded = AudioEmbeddingMediaIO().load_base64("", payload_text)

        assert loaded.shape == dense.shape

    def test_valid_sparse_tensor_accepted(self):
        """Baseline: a well-formed sparse tensor loads and comes back dense."""
        sparse = _create_valid_sparse_tensor()
        payload_text = _encode_tensor(sparse).decode("utf-8")

        # No exception is expected here.
        loaded = AudioEmbeddingMediaIO().load_base64("", payload_text)

        assert loaded.is_sparse is False

    def test_malicious_sparse_tensor_rejected(self):
        """Security: a sparse tensor with out-of-bounds indices is rejected."""
        handler = AudioEmbeddingMediaIO()
        payload_text = _encode_tensor(_create_malicious_sparse_tensor()).decode(
            "utf-8"
        )

        # Either RuntimeError or ValueError counts as a rejection.
        with pytest.raises((RuntimeError, ValueError)) as exc_info:
            handler.load_base64("", payload_text)

        # The message should point at the sparse tensor validation failure.
        error_msg = str(exc_info.value).lower()
        assert any(hint in error_msg for hint in ("sparse", "index", "bounds"))

    def test_load_bytes_validates(self):
        """Security: the load_bytes entry point rejects the same payload."""
        handler = AudioEmbeddingMediaIO()

        # Serialize without base64 so the raw-bytes entry point is exercised.
        stream = io.BytesIO()
        torch.save(_create_malicious_sparse_tensor(), stream)
        raw_payload = stream.getvalue()

        with pytest.raises((RuntimeError, ValueError)):
            handler.load_bytes(raw_payload)
class TestSparseTensorValidationIntegration:
    """
    End-to-end style checks that the malicious payload is blocked at each
    entry point used in this file.
    """

    def test_attack_scenario_completions_api(self, model_config):
        """
        Walk through an attack on the Completions API.

        Attack scenario:
        1. Attacker crafts malicious sparse tensor
        2. Encodes it as base64
        3. Sends to /v1/completions with prompt_embeds parameter
        4. Server should reject before memory corruption occurs
        """
        # Steps 1-2: the attacker's crafted, base64-encoded payload.
        attack_payload = _encode_tensor(_create_malicious_sparse_tensor())

        # Steps 3-4: loading the payload must raise instead of succeeding.
        renderer = CompletionRenderer(model_config)
        with pytest.raises((RuntimeError, ValueError)):
            renderer.load_prompt_embeds(attack_payload)

    def test_attack_scenario_chat_api_image(self):
        """
        Walk through an attack on the Chat API via image_embeds.

        Confirms the image embedding loader rejects the payload.
        """
        payload_text = _encode_tensor(_create_malicious_sparse_tensor()).decode(
            "utf-8"
        )
        handler = ImageEmbeddingMediaIO()

        with pytest.raises((RuntimeError, ValueError)):
            handler.load_base64("", payload_text)

    def test_attack_scenario_chat_api_audio(self):
        """
        Walk through an attack on the Chat API via audio_embeds.

        Confirms the audio embedding loader rejects the payload.
        """
        payload_text = _encode_tensor(_create_malicious_sparse_tensor()).decode(
            "utf-8"
        )
        handler = AudioEmbeddingMediaIO()

        with pytest.raises((RuntimeError, ValueError)):
            handler.load_base64("", payload_text)

    def test_multiple_valid_embeddings_in_batch(self, model_config):
        """
        Regression test: a batch of valid embeddings still loads.

        Guards against the validation breaking legitimate batch processing.
        """
        renderer = CompletionRenderer(model_config)

        batch = [_encode_tensor(_create_valid_dense_tensor()) for _ in range(3)]

        # Every item should load without error.
        loaded = renderer.load_prompt_embeds(batch)
        assert len(loaded) == 3

    def test_mixed_valid_and_malicious_rejected(self, model_config):
        """
        Security: one malicious tensor makes the whole batch fail.

        The batch holds two valid tensors with a malicious one between them;
        loading it must raise.
        """
        renderer = CompletionRenderer(model_config)

        leading_valid = _encode_tensor(_create_valid_dense_tensor())
        malicious = _encode_tensor(_create_malicious_sparse_tensor())
        trailing_valid = _encode_tensor(_create_valid_dense_tensor())
        mixed_batch = [leading_valid, malicious, trailing_valid]

        # The malicious entry should trigger the failure.
        with pytest.raises((RuntimeError, ValueError)):
            renderer.load_prompt_embeds(mixed_batch)
# Pytest fixtures
@pytest.fixture
def model_config():
    """Build a ModelConfig for facebook/opt-125m with prompt embeds enabled."""
    from vllm.config import ModelConfig

    model_name = "facebook/opt-125m"
    config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float32",
        seed=0,
        # The prompt embeds tests need this switched on.
        enable_prompt_embeds=True,
    )
    return config
tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py
View file @
a3f8d5dd
...
@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
...
@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming
,
run_tool_extraction_streaming
,
)
)
from
vllm.entrypoints.openai.protocol
import
FunctionCall
from
vllm.entrypoints.openai.protocol
import
FunctionCall
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers
import
ToolParser
,
ToolParserManager
SIMPLE_ARGS_DICT
=
{
SIMPLE_ARGS_DICT
=
{
"action"
:
"create"
,
"action"
:
"create"
,
...
...
tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py
View file @
a3f8d5dd
...
@@ -6,8 +6,8 @@ import json
...
@@ -6,8 +6,8 @@ import json
import
pytest
import
pytest
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.entrypoints.openai.tool_parsers.hermes_tool_parser
import
Hermes2ProToolParser
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers.hermes_tool_parser
import
Hermes2ProToolParser
from
....utils
import
RemoteOpenAIServer
from
....utils
import
RemoteOpenAIServer
...
...
tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py
View file @
a3f8d5dd
...
@@ -12,7 +12,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
...
@@ -12,7 +12,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming
,
run_tool_extraction_streaming
,
)
)
from
vllm.entrypoints.openai.protocol
import
FunctionCall
,
ToolCall
from
vllm.entrypoints.openai.protocol
import
FunctionCall
,
ToolCall
from
vllm.
entrypoints.openai.
tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.tool_parsers
import
ToolParser
,
ToolParserManager
def
make_tool_call
(
name
,
arguments
):
def
make_tool_call
(
name
,
arguments
):
...
...
tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py
View file @
a3f8d5dd
...
@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
...
@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
import
pytest
import
pytest
from
vllm.entrypoints.openai.protocol
import
ExtractedToolCallInformation
from
vllm.entrypoints.openai.protocol
import
ExtractedToolCallInformation
from
vllm.entrypoints.openai.tool_parsers.llama_tool_parser
import
Llama3JsonToolParser
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers.llama_tool_parser
import
Llama3JsonToolParser
@
pytest
.
fixture
@
pytest
.
fixture
...
...
tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py
View file @
a3f8d5dd
...
@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
...
@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming
,
run_tool_extraction_streaming
,
)
)
from
vllm.entrypoints.openai.protocol
import
FunctionCall
from
vllm.entrypoints.openai.protocol
import
FunctionCall
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers
import
ToolParser
,
ToolParserManager
# Test cases similar to pythonic parser but with Llama4 specific format
# Test cases similar to pythonic parser but with Llama4 specific format
SIMPLE_FUNCTION_OUTPUT
=
"[get_weather(city='LA', metric='C')]"
SIMPLE_FUNCTION_OUTPUT
=
"[get_weather(city='LA', metric='C')]"
...
...
tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py
View file @
a3f8d5dd
...
@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
...
@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming
,
run_tool_extraction_streaming
,
)
)
from
vllm.entrypoints.openai.protocol
import
FunctionCall
from
vllm.entrypoints.openai.protocol
import
FunctionCall
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers
import
ToolParser
,
ToolParserManager
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT
=
"get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_OUTPUT
=
"get_weather(city='San Francisco', metric='celsius')"
...
...
tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py
View file @
a3f8d5dd
...
@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
...
@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming
,
run_tool_extraction_streaming
,
)
)
from
vllm.entrypoints.openai.protocol
import
FunctionCall
from
vllm.entrypoints.openai.protocol
import
FunctionCall
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers
import
ToolParser
,
ToolParserManager
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT
=
"get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_OUTPUT
=
"get_weather(city='San Francisco', metric='celsius')"
...
...
tests/entrypoints/openai/tool_parsers/utils.py
View file @
a3f8d5dd
...
@@ -10,8 +10,8 @@ from vllm.entrypoints.openai.protocol import (
...
@@ -10,8 +10,8 @@ from vllm.entrypoints.openai.protocol import (
FunctionCall
,
FunctionCall
,
ToolCall
,
ToolCall
,
)
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers
import
ToolParser
class
StreamingToolReconstructor
:
class
StreamingToolReconstructor
:
...
...
tests/entrypoints/openai/utils.py
0 → 100644
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
from
collections.abc
import
AsyncGenerator
from
typing
import
Any
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
UsageInfo
,
)
def _merge_tool_call_delta(
    accumulated_tool_calls: list[dict[str, Any]],
    tool_call_delta: Any,
) -> None:
    """Fold one streamed tool-call delta into ``accumulated_tool_calls``.

    The list is mutated in place. Deltas address their tool call by
    ``index``; missing positions are filled with blank placeholders first.
    """
    while len(accumulated_tool_calls) <= tool_call_delta.index:
        accumulated_tool_calls.append(
            {
                "id": None,
                "type": "function",
                "function": {"name": "", "arguments": ""},
            }
        )

    entry = accumulated_tool_calls[tool_call_delta.index]
    if tool_call_delta.id:
        entry["id"] = tool_call_delta.id

    function_delta = tool_call_delta.function
    if not function_delta:
        return
    # Name and arguments arrive as fragments, so they are concatenated.
    if function_delta.name:
        entry["function"]["name"] += function_delta.name
    if function_delta.arguments:
        entry["function"]["arguments"] += function_delta.arguments


async def accumulate_streaming_response(
    stream_generator: AsyncGenerator[str, None],
) -> ChatCompletionResponse:
    """
    Accumulate streaming SSE chunks into a complete ChatCompletionResponse.

    Each item yielded by ``stream_generator`` is expected to look like
    ``"data: {json}\\n\\n"``. Blank items, the ``data: [DONE]`` marker, items
    without the ``data: `` prefix and items whose payload is not valid JSON
    are skipped. Surrounding whitespace on an item is ignored.

    The deltas of every choice in every chunk are merged into one message:
    content, reasoning and tool-call name/argument fragments are concatenated
    in arrival order, while role, finish reason and choice index keep the
    last value seen. ``id``, ``created`` and ``model`` come from the first
    parsed chunk.

    Returns:
        A ChatCompletionResponse with exactly one choice. Fields that never
        appeared in the stream fall back to placeholders (role
        ``"assistant"``, finish reason ``"stop"``, id ``"chatcmpl-test"``,
        created ``0``, model ``"test-model"``). Usage is always zeroed.
    """
    content_parts: list[str] = []
    reasoning_parts: list[str] = []
    accumulated_tool_calls: list[dict[str, Any]] = []
    role = None
    finish_reason = None
    response_id = None
    created = None
    model = None
    index = 0

    async for chunk_str in stream_generator:
        # Strip once so the marker check and the prefix check agree on
        # how surrounding whitespace is treated.
        line = chunk_str.strip()
        if not line or line == "data: [DONE]":
            continue
        if not line.startswith("data: "):
            continue

        # Only the JSON decode is best-effort; model validation errors
        # below are allowed to propagate.
        try:
            chunk_data = json.loads(line.removeprefix("data: "))
        except json.JSONDecodeError:
            continue

        chunk = ChatCompletionStreamResponse(**chunk_data)

        # Store metadata from first chunk
        if response_id is None:
            response_id = chunk.id
            created = chunk.created
            model = chunk.model

        for choice in chunk.choices:
            delta = choice.delta
            if delta.role:
                role = delta.role
            if delta.content:
                content_parts.append(delta.content)
            if delta.reasoning:
                reasoning_parts.append(delta.reasoning)
            for tool_call_delta in delta.tool_calls or ():
                _merge_tool_call_delta(accumulated_tool_calls, tool_call_delta)
            if choice.finish_reason:
                finish_reason = choice.finish_reason
            if choice.index is not None:
                index = choice.index

    # Empty content is reported as None; reasoning stays None unless at
    # least one non-empty reasoning delta was seen.
    message_kwargs: dict[str, Any] = {
        "role": role or "assistant",
        "content": "".join(content_parts) or None,
        "reasoning": "".join(reasoning_parts) if reasoning_parts else None,
    }
    # Only include tool_calls if there are any
    if accumulated_tool_calls:
        message_kwargs["tool_calls"] = accumulated_tool_calls

    choice = ChatCompletionResponseChoice(
        index=index,
        message=ChatMessage(**message_kwargs),
        finish_reason=finish_reason or "stop",
    )

    # Dummy usage values: this helper is only used by tests.
    usage = UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0)

    return ChatCompletionResponse(
        id=response_id or "chatcmpl-test",
        object="chat.completion",
        created=created or 0,
        model=model or "test-model",
        choices=[choice],
        usage=usage,
    )
# Maps each key an expectation may contain to the accessor that reads the
# matching value from a message. Accessors run only for keys that are
# present, and in this order.
_HARMONY_FIELD_GETTERS = {
    "role": lambda msg: msg.author.role,
    "author_name": lambda msg: msg.author.name,
    "channel": lambda msg: msg.channel,
    "recipient": lambda msg: msg.recipient,
    "content": lambda msg: msg.content[0].text,
    "content_type": lambda msg: msg.content_type,
    # Tool definitions are compared as the list of tool names.
    "tool_definitions": lambda msg: [
        t.name for t in msg.content[0].tools["functions"].tools
    ],
}


def verify_harmony_messages(
    messages: list[Any], expected_messages: list[dict[str, Any]]
) -> None:
    """Assert that ``messages`` match ``expected_messages`` pairwise.

    Each expectation is a dict that may contain any of the keys ``role``,
    ``author_name``, ``channel``, ``recipient``, ``content``,
    ``content_type`` and ``tool_definitions``. Only the keys present are
    checked; the corresponding attributes of the message are not touched
    otherwise. ``content`` is compared with the text of the first content
    item, and ``tool_definitions`` with the list of tool names found under
    ``content[0].tools["functions"].tools``.

    Raises:
        AssertionError: if the lengths differ or any checked field differs.
            The message names the offending message index and field, since
            asserts in helper modules do not get pytest's rewritten
            introspection by default.
    """
    assert len(messages) == len(expected_messages), (
        f"expected {len(expected_messages)} messages, got {len(messages)}"
    )

    for position, (msg, expected) in enumerate(zip(messages, expected_messages)):
        for field, read_actual in _HARMONY_FIELD_GETTERS.items():
            if field not in expected:
                continue
            actual = read_actual(msg)
            assert actual == expected[field], (
                f"message {position}: {field} is {actual!r}, "
                f"expected {expected[field]!r}"
            )
def verify_chat_response(
    response: ChatCompletionResponse,
    content: str | None = None,
    reasoning: str | None = None,
    tool_calls: list[tuple[str, str]] | None = None,
):
    """Assert that a single-choice chat response carries the expected message.

    Args:
        response: The response to check; it must contain exactly one choice.
        content: Expected message content. When None, the message content
            must be falsy (None or empty).
        reasoning: Expected message reasoning. When None, the message
            reasoning must be falsy.
        tool_calls: Expected ``(function name, arguments string)`` pairs, in
            order. When None or empty, the message must have no tool calls.

    Raises:
        AssertionError: on the first mismatch. Each failure carries an
            explicit message, since asserts in helper modules do not get
            pytest's rewritten introspection by default.
    """
    assert len(response.choices) == 1, (
        f"expected exactly one choice, got {len(response.choices)}"
    )
    message = response.choices[0].message

    if content is not None:
        assert message.content == content, (
            f"content is {message.content!r}, expected {content!r}"
        )
    else:
        assert not message.content, (
            f"expected no content, got {message.content!r}"
        )

    if reasoning is not None:
        assert message.reasoning == reasoning, (
            f"reasoning is {message.reasoning!r}, expected {reasoning!r}"
        )
    else:
        assert not message.reasoning, (
            f"expected no reasoning, got {message.reasoning!r}"
        )

    if not tool_calls:
        assert not message.tool_calls, (
            f"expected no tool calls, got {message.tool_calls!r}"
        )
        return

    assert message.tool_calls is not None, "expected tool calls, got None"
    assert len(message.tool_calls) == len(tool_calls), (
        f"expected {len(tool_calls)} tool calls, got {len(message.tool_calls)}"
    )
    for position, (tc, (expected_name, expected_args)) in enumerate(
        zip(message.tool_calls, tool_calls)
    ):
        assert tc.function.name == expected_name, (
            f"tool call {position}: name is {tc.function.name!r}, "
            f"expected {expected_name!r}"
        )
        assert tc.function.arguments == expected_args, (
            f"tool call {position}: arguments are {tc.function.arguments!r}, "
            f"expected {expected_args!r}"
        )
Prev
1
2
3
4
5
6
7
8
9
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment