Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d75f22e
Commit
8d75f22e
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori
parents
ce888aa4
7d80c73d
Changes
656
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1268 additions
and
81 deletions
+1268
-81
tests/reasoning/test_base_thinking_reasoning_parser.py
tests/reasoning/test_base_thinking_reasoning_parser.py
+46
-1
tests/reasoning/test_deepseekv3_reasoning_parser.py
tests/reasoning/test_deepseekv3_reasoning_parser.py
+1
-0
tests/reasoning/test_holo2_reasoning_parser.py
tests/reasoning/test_holo2_reasoning_parser.py
+188
-0
tests/standalone_tests/python_only_compile.sh
tests/standalone_tests/python_only_compile.sh
+5
-1
tests/test_config.py
tests/test_config.py
+5
-5
tests/test_envs.py
tests/test_envs.py
+51
-0
tests/tool_use/test_mistral_tool_parser.py
tests/tool_use/test_mistral_tool_parser.py
+863
-0
tests/tool_use/utils.py
tests/tool_use/utils.py
+27
-1
tests/transformers_utils/test_utils.py
tests/transformers_utils/test_utils.py
+7
-5
tests/utils.py
tests/utils.py
+2
-2
tests/utils_/test_argparse_utils.py
tests/utils_/test_argparse_utils.py
+0
-22
tests/v1/attention/test_attention_splitting.py
tests/v1/attention/test_attention_splitting.py
+33
-4
tests/v1/attention/utils.py
tests/v1/attention/utils.py
+2
-2
tests/v1/core/test_reset_prefix_cache_e2e.py
tests/v1/core/test_reset_prefix_cache_e2e.py
+4
-1
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+3
-3
tests/v1/core/utils.py
tests/v1/core/utils.py
+2
-2
tests/v1/cudagraph/test_cudagraph_dispatch.py
tests/v1/cudagraph/test_cudagraph_dispatch.py
+2
-2
tests/v1/cudagraph/test_cudagraph_mode.py
tests/v1/cudagraph/test_cudagraph_mode.py
+14
-26
tests/v1/determinism/test_batch_invariance.py
tests/v1/determinism/test_batch_invariance.py
+9
-3
tests/v1/determinism/test_online_batch_invariance.py
tests/v1/determinism/test_online_batch_invariance.py
+4
-1
No files found.
Too many changes to show.
To preserve performance only
656 of 656+
files are displayed.
Plain diff
Email patch
tests/reasoning/test_base_thinking_reasoning_parser.py
View file @
8d75f22e
...
...
@@ -112,7 +112,7 @@ class TestBaseThinkingReasoningParserMethods:
"""Test the is_reasoning_end method."""
parser
=
TestThinkingReasoningParser
(
test_tokenizer
)
end_token_id
=
parser
.
end_token_id
start_token_id
=
parser
.
start_token_id
# Test with end token present
assert
parser
.
is_reasoning_end
([
1
,
2
,
end_token_id
,
4
])
is
True
...
...
@@ -122,6 +122,51 @@ class TestBaseThinkingReasoningParserMethods:
# Test with empty list
assert
parser
.
is_reasoning_end
([])
is
False
# Test with interleaved thinking
assert
parser
.
is_reasoning_end
([
1
,
start_token_id
,
2
,
end_token_id
])
is
True
assert
parser
.
is_reasoning_end
([
1
,
start_token_id
,
2
,
3
])
is
False
assert
(
parser
.
is_reasoning_end
(
[
1
,
start_token_id
,
2
,
end_token_id
,
2
,
2
,
start_token_id
]
)
is
False
)
def test_is_reasoning_end_streaming(self, test_tokenizer):
    """Check is_reasoning_end_streaming over (all ids, delta ids) pairs.

    The table covers sequences with and without the end token, empty
    inputs, and interleaved thinking where a start token appears before
    and/or after an end token.
    """
    parser = TestThinkingReasoningParser(test_tokenizer)
    end = parser.end_token_id
    start = parser.start_token_id

    # (token ids seen so far, ids of the latest delta, expected result)
    scenarios = [
        ([1, 2, end], [end], True),
        ([1, 2, 3, 4], [4], False),
        ([], [], False),
        ([1, start, 2, end], [end], True),
        ([1, start, 2, 3], [3], False),
        ([1, start, 2, end, 2, start, 2], [2], False),
        ([1, start, 2, end, 2, 2], [2], False),
    ]
    for input_ids, delta_ids, expected in scenarios:
        assert parser.is_reasoning_end_streaming(input_ids, delta_ids) is expected
def
test_extract_content_ids
(
self
,
test_tokenizer
):
"""Test the extract_content_ids method."""
parser
=
TestThinkingReasoningParser
(
test_tokenizer
)
...
...
tests/reasoning/test_deepseekv3_reasoning_parser.py
View file @
8d75f22e
...
...
@@ -40,6 +40,7 @@ def test_identity_reasoning_parser_basic(tokenizer):
input_tokens
=
tokenizer
.
tokenize
(
input_text
)
input_ids
=
tokenizer
.
convert_tokens_to_ids
(
input_tokens
)
assert
parser
.
is_reasoning_end
(
input_ids
)
is
True
assert
parser
.
is_reasoning_end_streaming
(
input_ids
,
input_ids
)
is
True
# Test extract_content_ids returns all input_ids
assert
parser
.
extract_content_ids
(
input_ids
)
==
input_ids
...
...
tests/reasoning/test_holo2_reasoning_parser.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
transformers
import
AutoTokenizer
from
tests.reasoning.utils
import
run_reasoning_extraction
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
from
vllm.reasoning.deepseek_r1_reasoning_parser
import
DeepSeekR1ReasoningParser
from
vllm.reasoning.holo2_reasoning_parser
import
Holo2ReasoningParser
from
vllm.reasoning.identity_reasoning_parser
import
IdentityReasoningParser
REASONING_MODEL_NAME
=
"HCompany/Holo2-4B"
@pytest.fixture(scope="module")
def tokenizer():
    """Module-scoped tokenizer loaded for REASONING_MODEL_NAME."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    return loaded_tokenizer
@pytest.mark.parametrize(
    "thinking,expected_parser_type",
    [
        (True, DeepSeekR1ReasoningParser),
        (False, IdentityReasoningParser),
    ],
)
def test_parser_selection(tokenizer, thinking, expected_parser_type):
    """The "thinking" chat-template flag decides which parser Holo2 wraps."""
    template_kwargs = {"thinking": thinking}
    parser = Holo2ReasoningParser(tokenizer, chat_template_kwargs=template_kwargs)

    wrapped = parser._parser
    assert isinstance(wrapped, expected_parser_type)
def test_holo2_default_parser_is_deepseekr1(tokenizer):
    """Without chat_template_kwargs the wrapped parser is the DeepSeek-R1 one."""
    wrapped = Holo2ReasoningParser(tokenizer)._parser
    assert isinstance(wrapped, DeepSeekR1ReasoningParser)
def test_holo2_supports_structured_output(tokenizer):
    """is_reasoning_end must recognise the </think> token id.

    The structured output manager uses the reasoning parser, mainly through
    is_reasoning_end, to check whether the reasoning content has ended
    before applying the grammar. This test checks that the parser correctly
    identifies the end of the reasoning content.
    """
    # Important: do not pass chat_template_kwargs here, as that is done in
    # the StructuredOutputManager.
    parser = Holo2ReasoningParser(tokenizer)

    close_tag_ids = tokenizer.encode("</think>", add_special_tokens=False)
    end_token_id = close_tag_ids[0]

    ends_with_close = [1, 2, 4, end_token_id]
    without_close = [1, 2, 4]
    close_then_more = [1, 2, 4, end_token_id, 5]

    assert parser.is_reasoning_end(ends_with_close)
    assert not parser.is_reasoning_end(without_close)
    assert parser.is_reasoning_end(close_then_more)
# Expectation tables: each holds the raw model "output" plus the "reasoning"
# and "content" values the parser is expected to extract from it.

# thinking is True, non-streaming
WITH_THINK = dict(
    output="This is a reasoning section</think>This is the rest",
    reasoning="This is a reasoning section",
    content="This is the rest",
)
# thinking is True, streaming
WITH_THINK_STREAM = dict(
    output="This is a reasoning section</think>This is the rest",
    reasoning="This is a reasoning section",
    content="This is the rest",
)
# thinking is False, non-streaming
THINKING_DISABLED = dict(
    output="This is the rest",
    reasoning=None,
    content="This is the rest",
)
# thinking is False, streaming
THINKING_DISABLED_STREAM = dict(
    output="This is the rest",
    reasoning=None,
    content="This is the rest",
)
# thinking is False but the model output </think>, non-streaming
THINKING_DISABLED_WITH_CLOSE_TAG = dict(
    output="</think>This is the rest",
    reasoning=None,
    content="</think>This is the rest",
)
# thinking is False but the model output </think>, streaming
THINKING_DISABLED_WITH_CLOSE_TAG_STREAM = dict(
    output="some text</think>This is the rest",
    reasoning=None,
    content="some text</think>This is the rest",
)
# reasoning that is closed but followed by no content
COMPLETE_REASONING = dict(
    output="This is a reasoning section</think>",
    reasoning="This is a reasoning section",
    content=None,
)
# Parametrization table for test_reasoning. Each row is
# (test id, streaming flag, expectation dict, chat_template_kwargs).
TEST_CASES = [
    pytest.param(streaming, expectation, template_kwargs, id=case_id)
    for case_id, streaming, expectation, template_kwargs in (
        ("with_think", False, WITH_THINK, None),
        ("with_think_stream", True, WITH_THINK_STREAM, None),
        ("with_think_enabled", False, WITH_THINK, {"thinking": True}),
        ("with_think_stream_enabled", True, WITH_THINK_STREAM, {"thinking": True}),
        ("thinking_disabled", False, THINKING_DISABLED, {"thinking": False}),
        (
            "thinking_disabled_stream",
            True,
            THINKING_DISABLED_STREAM,
            {"thinking": False},
        ),
        (
            "thinking_disabled_with_close_tag",
            False,
            THINKING_DISABLED_WITH_CLOSE_TAG,
            {"thinking": False},
        ),
        (
            "thinking_disabled_with_close_tag_stream",
            True,
            THINKING_DISABLED_WITH_CLOSE_TAG_STREAM,
            {"thinking": False},
        ),
        ("complete_reasoning", False, COMPLETE_REASONING, None),
        ("complete_reasoning_stream", True, COMPLETE_REASONING, None),
    )
]
@pytest.mark.parametrize("streaming, param_dict, chat_template_kwargs", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
    chat_template_kwargs: dict | None,
    tokenizer,
):
    """Run the "holo2" parser over tokenized output and compare both parts.

    The expected "reasoning" and "content" values come from param_dict; the
    parser is looked up by name through ReasoningParserManager.
    """
    raw_tokens = tokenizer.tokenize(param_dict["output"])
    # Turn every token back into its own string piece.
    output_tokens: list[str] = []
    for raw_token in raw_tokens:
        output_tokens.append(tokenizer.convert_tokens_to_string([raw_token]))

    parser_cls = ReasoningParserManager.get_reasoning_parser("holo2")
    parser: ReasoningParser = parser_cls(
        tokenizer,
        chat_template_kwargs=chat_template_kwargs,
    )

    reasoning, content = run_reasoning_extraction(
        parser, output_tokens, streaming=streaming
    )

    assert reasoning == param_dict["reasoning"]
    assert content == param_dict["content"]
tests/standalone_tests/python_only_compile.sh
View file @
8d75f22e
...
...
@@ -5,6 +5,10 @@
set
-e
set
-x
merge_base_commit
=
$(
git merge-base HEAD origin/main
)
echo
"Current merge base commit with main:
$merge_base_commit
"
git show
--oneline
-s
$merge_base_commit
cd
/vllm-workspace/
# uninstall vllm
...
...
@@ -18,7 +22,7 @@ apt autoremove -y
echo
'import os; os.system("touch /tmp/changed.file")'
>>
vllm/__init__.py
VLLM_
TEST_USE_
PRECOMPILED_
NIGHTLY_WHEEL
=
1
VLLM_USE_PRECOMPILED
=
1 pip3
install
-vvv
-e
.
VLLM_PRECOMPILED_
WHEEL_COMMIT
=
$merge_base_commit
VLLM_USE_PRECOMPILED
=
1 pip3
install
-vvv
-e
.
# Run the script
python3
-c
'import vllm'
...
...
tests/test_config.py
View file @
8d75f22e
...
...
@@ -97,7 +97,7 @@ def test_update_config():
(
"intfloat/multilingual-e5-small"
,
"pooling"
,
"none"
,
"embed"
),
(
"jason9693/Qwen2.5-1.5B-apeach"
,
"pooling"
,
"classify"
,
"classify"
),
(
"cross-encoder/ms-marco-MiniLM-L-6-v2"
,
"pooling"
,
"none"
,
"classify"
),
(
"Qwen/Qwen2.5-Math-RM-72B"
,
"pooling"
,
"none"
,
"
rewar
d"
),
(
"Qwen/Qwen2.5-Math-RM-72B"
,
"pooling"
,
"none"
,
"
embe
d"
),
(
"openai/whisper-small"
,
"generate"
,
"none"
,
"transcription"
),
],
)
...
...
@@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
(
"internlm/internlm2-1_8b-reward"
,
"decoder"
,
Fals
e
,
"Pooling models with all pooling
does not
support chunked prefill."
,
Tru
e
,
"Pooling models with
causal attn and
all pooling support chunked prefill."
,
),
(
"BAAI/bge-base-en"
,
...
...
@@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported(
(
"internlm/internlm2-1_8b-reward"
,
"decoder"
,
Fals
e
,
"Pooling models with all pooling
does not
support prefix caching."
,
Tru
e
,
"Pooling models with
causal attn and
all pooling support prefix caching."
,
),
(
"BAAI/bge-base-en"
,
...
...
tests/test_envs.py
View file @
8d75f22e
...
...
@@ -365,3 +365,54 @@ class TestEnvSetWithChoices:
with
patch
.
dict
(
os
.
environ
,
{
"TEST_ENV"
:
"option1,option1,option2"
}):
env_func
=
env_set_with_choices
(
"TEST_ENV"
,
[],
[
"option1"
,
"option2"
])
assert
env_func
()
==
{
"option1"
,
"option2"
}
class TestVllmConfigureLogging:
    """Test cases for VLLM_CONFIGURE_LOGGING environment variable."""

    @staticmethod
    def _clear_envs_cache():
        """Clear the cache on envs.__getattr__ if it exposes cache_clear."""
        if hasattr(envs.__getattr__, "cache_clear"):
            envs.__getattr__.cache_clear()

    def test_configure_logging_defaults_to_true(self):
        """Test that VLLM_CONFIGURE_LOGGING defaults to True when not set."""
        # patch.dict restores os.environ on exit, so removing the variable
        # inside the context does not leak into other tests.
        with patch.dict(os.environ, {}, clear=False):
            os.environ.pop("VLLM_CONFIGURE_LOGGING", None)
            self._clear_envs_cache()

            value = envs.VLLM_CONFIGURE_LOGGING
            assert value is True
            assert isinstance(value, bool)

    def test_configure_logging_with_zero_string(self):
        """Test that VLLM_CONFIGURE_LOGGING='0' evaluates to False."""
        with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "0"}):
            self._clear_envs_cache()

            value = envs.VLLM_CONFIGURE_LOGGING
            assert value is False
            assert isinstance(value, bool)

    def test_configure_logging_with_one_string(self):
        """Test that VLLM_CONFIGURE_LOGGING='1' evaluates to True."""
        with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "1"}):
            self._clear_envs_cache()

            value = envs.VLLM_CONFIGURE_LOGGING
            assert value is True
            assert isinstance(value, bool)

    def test_configure_logging_with_invalid_value_raises_error(self):
        """Test that invalid VLLM_CONFIGURE_LOGGING value raises ValueError."""
        with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "invalid"}):
            self._clear_envs_cache()

            with pytest.raises(ValueError, match="invalid literal for int"):
                _ = envs.VLLM_CONFIGURE_LOGGING
tests/tool_use/test_mistral_tool_parser.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
from
collections.abc
import
Generator
import
partial_json_parser
import
pytest
from
mistral_common.protocol.instruct.messages
import
AssistantMessage
from
mistral_common.protocol.instruct.request
import
InstructRequest
from
mistral_common.protocol.instruct.tool_calls
import
FunctionCall
,
ToolCall
from
partial_json_parser.core.options
import
Allow
from
vllm.entrypoints.openai.protocol
import
DeltaMessage
,
DeltaToolCall
from
vllm.entrypoints.openai.tool_parsers.mistral_tool_parser
import
MistralToolParser
from
vllm.tokenizers
import
(
MistralTokenizer
,
TokenizerLike
,
get_tokenizer
,
)
from
vllm.tokenizers.detokenizer_utils
import
detokenize_incrementally
@pytest.fixture(scope="module")
def mistral_pre_v11_tokenizer():
    """Module-scoped tokenizer for Mistral-7B-Instruct-v0.3."""
    return get_tokenizer(tokenizer_name="mistralai/Mistral-7B-Instruct-v0.3")
@pytest.fixture(scope="module")
def mistral_tokenizer():
    """Module-scoped tokenizer for Mistral-Small-3.2, loaded in "mistral" mode."""
    return get_tokenizer(
        tokenizer_name="mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        tokenizer_mode="mistral",
    )
@pytest.fixture
def mistral_pre_v11_tool_parser(mistral_pre_v11_tokenizer):
    """Fresh MistralToolParser per test, built on the pre-v11 tokenizer."""
    tool_parser = MistralToolParser(mistral_pre_v11_tokenizer)
    return tool_parser
@pytest.fixture
def mistral_tool_parser(mistral_tokenizer):
    """Fresh MistralToolParser per test, built on the mistral-mode tokenizer."""
    tool_parser = MistralToolParser(mistral_tokenizer)
    return tool_parser
def assert_tool_calls(
    actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
    expected_tool_calls: list[ToolCall],
):
    """Assert that the parsed tool calls match the expected ones pairwise.

    Args:
        actual_tool_calls: Tool calls produced by the parser, either complete
            ``ToolCall`` objects or streamed ``DeltaToolCall`` objects.
        expected_tool_calls: Reference tool calls; only their function name
            and arguments are compared.

    Raises:
        AssertionError: If the lengths differ, an id is not a 9-character
            string, a required field is missing, or a function name or
            arguments string does not match.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        # Ids are not compared to the expected ones, only checked for shape.
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) == 9

        if isinstance(actual_tool_call, ToolCall):
            assert actual_tool_call.type == "function"
        elif isinstance(actual_tool_call, DeltaToolCall):
            assert actual_tool_call.function is not None
            assert actual_tool_call.function.name is not None
            assert actual_tool_call.function.arguments is not None
        assert actual_tool_call.function is not None

        # The messages previously used JS-style "${...}", which left a stray
        # "$" in the output; they now also show the expected value.
        assert actual_tool_call.function.name == expected_tool_call.function.name, (
            f"got wrong function name: {actual_tool_call.function.name!r}, "
            f"expected {expected_tool_call.function.name!r}"
        )
        assert (
            actual_tool_call.function.arguments
            == expected_tool_call.function.arguments
        ), (
            f"got wrong function argument: {actual_tool_call.function.arguments!r}, "
            f"expected {expected_tool_call.function.arguments!r}"
        )
def fix_tool_call_tokenization(
    tokens: list[int],
    mistral_tool_parser: MistralToolParser,
    mistral_tokenizer: TokenizerLike,
):
    """
    Replaces the textual token sequence for [TOOL_CALLS]
    with its single special token ID.

    Args:
        tokens: Token ids to scan.
        mistral_tool_parser: Provides ``bot_token`` (the text to look for)
            and ``bot_token_id`` (the id that replaces each match).
        mistral_tokenizer: Encodes ``bot_token`` without special tokens to
            obtain the id sequence to search for.

    Returns:
        ``tokens`` itself when no replacement is possible (empty input, input
        shorter than the searched sequence, or an empty searched sequence);
        otherwise a new list with every non-overlapping match, scanned left
        to right, replaced by ``bot_token_id``.
    """
    textual_tool_call_token_ids = mistral_tokenizer.encode(
        text=mistral_tool_parser.bot_token,
        add_special_tokens=False,
    )
    # textual_tool_call_token_ids must not contain special tokens like bos, eos etc
    special_tool_call_token_ids = [mistral_tool_parser.bot_token_id]
    target_len = len(textual_tool_call_token_ids)

    # An empty search sequence would match at every index without advancing
    # the cursor, so the loop below would never end; bail out instead.
    if target_len == 0:
        return tokens

    # If the input is too short to contain the sequence, no replacement is possible
    if not tokens or len(tokens) < target_len:
        return tokens

    result_tokens = []
    i = 0
    while i < len(tokens):
        # Check if the slice from the current position matches the target sequence
        if tokens[i : i + target_len] == textual_tool_call_token_ids:
            # If it matches, add the replacement and jump the index forward
            result_tokens.extend(special_tool_call_token_ids)
            i += target_len
        else:
            # Otherwise, just add the current token and move to the next one
            result_tokens.append(tokens[i])
            i += 1

    return result_tokens
def stream_delta_message_generator(
    mistral_tool_parser: MistralToolParser,
    mistral_tokenizer: TokenizerLike,
    model_output: str | None,
    tools: list[tuple[str, str]] | None,
) -> Generator[DeltaMessage, None, None]:
    """Feed the tool parser one token at a time and yield its delta messages.

    The token ids come from ``tools`` when the tokenizer is a
    MistralTokenizer with version >= 11, and from ``model_output`` otherwise.
    Only truthy results of ``extract_tool_calls_streaming`` are yielded.
    """
    uses_mistral_tokenizer = isinstance(mistral_tokenizer, MistralTokenizer)

    if uses_mistral_tokenizer and mistral_tokenizer.version >= 11:
        # With the newer versions of the tokenizer,
        # we cannot tokenize free text
        # so we need to create a list of messages to get tokenized
        assert tools is not None
        calls = [
            ToolCall(function=FunctionCall(name=name, arguments=arg))
            for (name, arg) in tools
        ]
        assistant_msg = AssistantMessage(tool_calls=calls)
        request = InstructRequest(messages=[assistant_msg])
        all_token_ids = mistral_tokenizer.instruct.encode_instruct(request).tokens
    else:
        # Older versions of the tokenizer are
        # able to encode directly the model's output (free text) into tokens
        assert model_output is not None
        all_token_ids = mistral_tokenizer.encode(
            model_output, add_special_tokens=False
        )

    all_token_ids = fix_tool_call_tokenization(
        all_token_ids, mistral_tool_parser, mistral_tokenizer
    )

    # Incremental detokenization state carried from one token to the next.
    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0

    for position, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:position]
        current_token_ids = all_token_ids[: position + 1]

        detokenized = detokenize_incrementally(
            tokenizer=mistral_tokenizer,
            all_input_ids=current_token_ids,
            prev_tokens=previous_tokens,
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=uses_mistral_tokenizer,
            spaces_between_special_tokens=True,
        )
        new_tokens, delta_text, new_prefix_offset, new_read_offset = detokenized

        current_text = previous_text + delta_text
        delta_message = mistral_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=None,  # type: ignore[arg-type]
        )
        if delta_message:
            yield delta_message

        # Roll the state forward for the next token.
        previous_text = current_text
        if previous_tokens:
            previous_tokens = previous_tokens + new_tokens
        else:
            previous_tokens = new_tokens
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset
def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser):
    """Plain text yields no tool calls and is returned unchanged as content."""
    model_output = "This is a test"
    result = mistral_pre_v11_tool_parser.extract_tool_calls(
        model_output,
        request=None,  # type: ignore[arg-type]
    )

    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == model_output
@pytest.mark.parametrize(
    ["model_output", "expected_tool_calls", "expected_content"],
    [
        pytest.param(
            """[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                )
            ],
            None,
            id="single_tool_add",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
            id="single_tool_weather",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
            id="argument_before_name",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps({"name": "John Doe"}),
                    )
                )
            ],
            None,
            id="argument_before_name_and_name_in_argument",
        ),
    ],
)
def test_extract_tool_calls_pre_v11_tokenizer(
    mistral_pre_v11_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction with the pre-v11 parser on JSON-list output."""
    result = mistral_pre_v11_tool_parser.extract_tool_calls(
        model_output,
        request=None,  # type: ignore[arg-type]
    )

    assert result.tools_called
    assert_tool_calls(result.tool_calls, expected_tool_calls)
    assert result.content == expected_content
@pytest.mark.parametrize(
    ["model_output", "expected_tool_calls", "expected_content"],
    [
        pytest.param(
            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            None,
            id="single_tool_add",
        ),
        pytest.param(
            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            None,
            id="single_tool_weather",
        ),
        pytest.param(
            """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="multiply", arguments=json.dumps({"a": 3, "b": 6})
                    )
                ),
            ],
            None,
            id="multiple_tool_calls",
        ),
    ],
)
def test_extract_tool_calls(
    mistral_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction on ``[TOOL_CALLS]name{json}`` style output."""
    result = mistral_tool_parser.extract_tool_calls(
        model_output,
        request=None,  # type: ignore[arg-type]
    )

    assert result.tools_called
    assert_tool_calls(result.tool_calls, expected_tool_calls)
    assert result.content == expected_content
def _test_extract_tool_calls_streaming(
    tool_parser, tokenizer, model_output, tools, expected_tool_calls, expected_content
):
    """Stream the output through the parser and compare the reassembled result.

    Content deltas are concatenated; tool-call deltas are collected per
    tool-call index into ids, names and argument strings, then rebuilt into
    ToolCall objects and checked with assert_tool_calls.
    """
    collected_content: str = ""
    names: list[str] = []
    args_by_index: list[str] = []
    current_index: int = -1
    ids_by_index: list[str | None] = []

    deltas = stream_delta_message_generator(
        tool_parser, tokenizer, model_output, tools
    )
    for delta_message in deltas:
        # role should never be streamed from tool parser
        assert not delta_message.role

        if delta_message.content:
            collected_content += delta_message.content

        streamed = delta_message.tool_calls
        if not streamed:
            continue

        # make sure only one diff is present - correct even for parallel
        assert len(streamed) == 1
        tool_call = streamed[0]

        assert len(tool_parser.prev_tool_call_arr) > 0

        # if a new tool is being called, set up empty arguments
        if tool_call.index != current_index:
            current_index = tool_call.index
            args_by_index.append("")
            ids_by_index.append(None)

        # if a tool call ID is streamed, make sure one hasn't been already
        if tool_call.id and not ids_by_index[tool_call.index]:
            ids_by_index[tool_call.index] = tool_call.id

        # if parts of the function start being streamed
        function = tool_call.function
        if not function:
            continue

        # if the function name is defined, set it. it should be streamed
        # IN ENTIRETY, exactly one time.
        if function.name:
            assert isinstance(function.name, str)
            names.append(function.name)

        if function.arguments:
            # make sure they're a string and then add them to the list
            assert isinstance(function.arguments, str)
            args_by_index[tool_call.index] += function.arguments

    assert collected_content == expected_content

    actual_tool_calls = []
    for call_id, name, args_str in zip(ids_by_index, names, args_by_index):
        arguments = partial_json_parser.ensure_json(
            args_str, Allow.OBJ | Allow.STR
        )
        actual_tool_calls.append(
            ToolCall(
                id=call_id,
                function=FunctionCall(name=name, arguments=arguments),
            )
        )
    assert_tool_calls(actual_tool_calls, expected_tool_calls)
@pytest.mark.parametrize(
    ["model_output", "expected_tool_calls", "expected_content"],
    [
        pytest.param(
            """This is a test""",
            [],
            """This is a test""",
            id="no_tools",
        ),
        pytest.param(
            """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3, "b": 4})
                    )
                )
            ],
            "",
            id="single_tool_add",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": "3", "b": "4"})
                    )
                )
            ],
            "",
            id="single_tool_add_strings",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
            id="single_tool_weather",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
            id="argument_before_name",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps({"name": "John Doe"}),
                    )
                )
            ],
            "",
            id="argument_before_name_and_name_in_argument",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
            id="multiple_tools",
        ),
    ],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer(
    mistral_pre_v11_tool_parser,
    mistral_pre_v11_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Token-by-token streaming extraction with the pre-v11 tokenizer.

    Free text is tokenized directly, so ``tools`` is passed as None.
    """
    _test_extract_tool_calls_streaming(
        tool_parser=mistral_pre_v11_tool_parser,
        tokenizer=mistral_pre_v11_tokenizer,
        model_output=model_output,
        tools=None,
        expected_tool_calls=expected_tool_calls,
        expected_content=expected_content,
    )
@pytest.mark.parametrize(
    ["tools", "expected_tool_calls", "expected_content"],
    [
        pytest.param(
            [("add", '{"a": 3, "b": 4}')],
            # [TOOL_CALLS]add{"a": 3, "b": 4}
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3, "b": 4})
                    )
                )
            ],
            "",
            id="single_tool_add",
        ),
        pytest.param(
            [("add_two_strings", '{"a": "3", "b": "4"}')],
            # [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"}
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_two_strings",
                        arguments=json.dumps({"a": "3", "b": "4"}),
                    )
                )
            ],
            "",
            id="single_tool_add_strings",
        ),
        pytest.param(
            [
                ("add", '{"a": 3.5, "b": 4}'),
                (
                    "get_current_weather",
                    '{"city": "San Francisco", "state": "CA", "unit": "celsius"}',  # noqa: E501
                ),
            ],
            # [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"} # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                ),
            ],
            "",
            id="multiple_tools",
        ),
    ],
)
def test_extract_tool_calls_streaming(
    mistral_tool_parser,
    mistral_tokenizer,
    tools,
    expected_tool_calls,
    expected_content,
):
    """Token-by-token streaming extraction with the mistral-mode tokenizer.

    The tokens are built from ``tools`` (name, JSON arguments) pairs, so
    ``model_output`` is passed as None.
    """
    _test_extract_tool_calls_streaming(
        tool_parser=mistral_tool_parser,
        tokenizer=mistral_tokenizer,
        model_output=None,
        tools=tools,
        expected_tool_calls=expected_tool_calls,
        expected_content=expected_content,
    )
@pytest.mark.parametrize(
    ["model_output", "expected_tool_calls", "expected_content"],
    [
        pytest.param(
            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            "",
            id="single_tool_add",
        ),
        pytest.param(
            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {"city": "San Francisco", "state": "CA", "unit": "celsius"}
                        ),
                    )
                )
            ],
            "",
            id="single_tool_weather",
        ),
        pytest.param(
            """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add", arguments=json.dumps({"a": 3.5, "b": 4})
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="multiply", arguments=json.dumps({"a": 3, "b": 6})
                    )
                ),
            ],
            "",
            id="multiple_tool_calls",
        ),
        pytest.param(
            # Additional content should not be after the tool calls
            """bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add_this_and_that",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                )
            ],
            "bla",
            id="content_before_tool",
        ),
        pytest.param(
            # Complex
            """[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="bash",
                        arguments=json.dumps(
                            {"command": "print(\"hello world!\")\nre.compile(r'{}')"}
                        ),
                    )
                )
            ],
            "",
            id="complex",
        ),
    ],
)
def test_extract_tool_calls_streaming_one_chunk(
    mistral_tool_parser,
    mistral_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Pass the whole output to the streaming extractor as a single delta."""
    if isinstance(mistral_tokenizer, MistralTokenizer):
        encoded_ids = mistral_tokenizer.encode(model_output)
    else:
        encoded_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)
    # Collapse the textual [TOOL_CALLS] token run into its special token id.
    all_token_ids = fix_tool_call_tokenization(
        encoded_ids, mistral_tool_parser, mistral_tokenizer
    )

    delta_message = mistral_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=model_output,
        delta_text=model_output,
        previous_token_ids=[],
        current_token_ids=all_token_ids,
        delta_token_ids=all_token_ids,
        request=None,  # type: ignore[arg-type]
    )

    assert isinstance(delta_message, DeltaMessage)
    assert len(delta_message.tool_calls) == len(expected_tool_calls)
    assert_tool_calls(delta_message.tool_calls, expected_tool_calls)

    # A missing content field is accepted only when no content is expected.
    actual_content = delta_message.content
    if actual_content is None:
        assert expected_content == ""
    else:
        assert actual_content == expected_content
@pytest.mark.parametrize(
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        pytest.param(
            """This is a test""",
            [],
            """This is a test""",
            id="no_tools",
        ),
        pytest.param(
            """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": 3, "b": 4}),
                    )
                )
            ],
            "",
            id="single_tool_add",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": "3", "b": "4"}),
                    )
                )
            ],
            "",
            id="single_tool_add_strings",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "San Francisco",
                                "state": "CA",
                                "unit": "celsius",
                            }
                        ),
                    )
                )
            ],
            "",
            id="single_tool_weather",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "San Francisco",
                                "state": "CA",
                                "unit": "celsius",
                            }
                        ),
                    )
                )
            ],
            "",
            id="argument_before_name",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_age",
                        arguments=json.dumps({"name": "John Doe"}),
                    )
                )
            ],
            "",
            id="argument_before_name_and_name_in_argument",
        ),
        pytest.param(
            """[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="add",
                        arguments=json.dumps({"a": 3.5, "b": 4}),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "San Francisco",
                                "state": "CA",
                                "unit": "celsius",
                            }
                        ),
                    )
                ),
            ],
            "",
            id="multiple_tools",
        ),
    ],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk(
    mistral_pre_v11_tool_parser,
    mistral_pre_v11_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Stream a pre-v11-format output through the parser as a single delta.

    The parametrized outputs place a JSON list of ``{"name", "arguments"}``
    objects after ``[TOOL_CALLS]``. Each output is tokenized, passed through
    ``fix_tool_call_tokenization`` and handed to
    ``extract_tool_calls_streaming`` in one call with empty "previous" state.
    The returned delta must carry ``expected_tool_calls`` and
    ``expected_content``.
    """
    # A MistralTokenizer is encoded with no extra arguments; any other
    # tokenizer is told not to add special tokens.
    encode_kwargs = (
        {}
        if isinstance(mistral_pre_v11_tokenizer, MistralTokenizer)
        else {"add_special_tokens": False}
    )
    token_ids = fix_tool_call_tokenization(
        mistral_pre_v11_tokenizer.encode(model_output, **encode_kwargs),
        mistral_pre_v11_tool_parser,
        mistral_pre_v11_tokenizer,
    )

    # One chunk: previous state is empty, current and delta are both the
    # full output.
    delta = mistral_pre_v11_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=model_output,
        delta_text=model_output,
        previous_token_ids=[],
        current_token_ids=token_ids,
        delta_token_ids=token_ids,
        request=None,  # type: ignore[arg-type]
    )

    assert isinstance(delta, DeltaMessage)
    streamed_tool_calls = delta.tool_calls
    assert len(streamed_tool_calls) == len(expected_tool_calls)
    assert_tool_calls(streamed_tool_calls, expected_tool_calls)

    # A missing content field is accepted only when no content is expected.
    if delta.content is not None:
        assert delta.content == expected_content
    else:
        assert expected_content == ""
tests/tool_use/utils.py
View file @
8d75f22e
...
...
@@ -123,7 +123,7 @@ CONFIGS: dict[str, ServerConfig] = {
"supports_parallel"
:
True
,
"extended"
:
True
,
},
"mistral"
:
{
"mistral
-7b
"
:
{
"model"
:
"mistralai/Mistral-7B-Instruct-v0.3"
,
"arguments"
:
[
"--enforce-eager"
,
...
...
@@ -145,6 +145,32 @@ CONFIGS: dict[str, ServerConfig] = {
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
,
"supports_parallel"
:
True
,
},
"mistral-small-3.2"
:
{
"model"
:
"mistralai/Mistral-Small-3.2-24B-Instruct-2506"
,
"arguments"
:
[
"--enforce-eager"
,
"--no-enable-prefix-caching"
,
"--tool-call-parser"
,
"mistral"
,
"--tokenizer-mode"
,
"mistral"
,
"--config-format"
,
"mistral"
,
"--load-format"
,
"mistral"
,
"--tensor-parallel-size"
,
"4"
,
'--ignore-patterns="consolidated.safetensors"'
,
],
"system_prompt"
:
"You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
,
"supports_parallel"
:
True
,
"extended"
:
True
,
},
# FIXME: This test currently fails, need to debug why.
# "granite20b": {
...
...
tests/transformers_utils/test_utils.py
View file @
8d75f22e
...
...
@@ -5,13 +5,15 @@ from unittest.mock import patch
import
pytest
from
vllm.transformers_utils.gguf_utils
import
(
is_gguf
,
is_remote_gguf
,
split_remote_gguf
,
)
from
vllm.transformers_utils.utils
import
(
is_cloud_storage
,
is_gcs
,
is_gguf
,
is_remote_gguf
,
is_s3
,
split_remote_gguf
,
)
...
...
@@ -132,7 +134,7 @@ class TestSplitRemoteGGUF:
class
TestIsGGUF
:
"""Test is_gguf utility function."""
@
patch
(
"vllm.transformers_utils.utils.check_gguf_file"
,
return_value
=
True
)
@
patch
(
"vllm.transformers_utils.
gguf_
utils.check_gguf_file"
,
return_value
=
True
)
def
test_is_gguf_with_local_file
(
self
,
mock_check_gguf
):
"""Test is_gguf with local GGUF file."""
assert
is_gguf
(
"/path/to/model.gguf"
)
...
...
@@ -149,7 +151,7 @@ class TestIsGGUF:
assert
not
is_gguf
(
"repo/model:quant"
)
assert
not
is_gguf
(
"repo/model:INVALID"
)
@
patch
(
"vllm.transformers_utils.utils.check_gguf_file"
,
return_value
=
False
)
@
patch
(
"vllm.transformers_utils.
gguf_
utils.check_gguf_file"
,
return_value
=
False
)
def
test_is_gguf_false
(
self
,
mock_check_gguf
):
"""Test is_gguf returns False for non-GGUF models."""
assert
not
is_gguf
(
"unsloth/Qwen3-0.6B"
)
...
...
tests/utils.py
View file @
8d75f22e
...
...
@@ -1225,9 +1225,9 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
try
:
import
aiter
# noqa: F401
attn_backend_list
.
append
(
"
FLASH_ATTN
"
)
attn_backend_list
.
append
(
"
ROCM_AITER_FA
"
)
except
Exception
:
print
(
"Skip
FLASH_ATTN
on ROCm as aiter is not installed"
)
print
(
"Skip
ROCM_AITER_FA
on ROCm as aiter is not installed"
)
return
attn_backend_list
elif
current_platform
.
is_xpu
():
...
...
tests/utils_/test_argparse_utils.py
View file @
8d75f22e
...
...
@@ -458,25 +458,3 @@ def test_flat_product():
(
3
,
4
,
"a"
,
5
,
6
),
(
3
,
4
,
"b"
,
5
,
6
),
]
def test_o_legacy_syntax_deprecation(caplog_vllm):
    """Check that legacy ``-O.<key>=<value>`` flags still parse, with a warning.

    Each dotted flag must end up in ``compilation_config`` as a one-entry
    dict, and the first use must log the deprecation notice.
    """
    deprecation_notice = (
        "The -O.* dotted syntax for --compilation-config is deprecated"
    )

    parser = FlexibleArgumentParser()
    parser.add_argument("-cc", "--compilation-config", type=json.loads)

    # -O.backend must be converted correctly AND emit the warning.
    backend_ns = parser.parse_args(["-O.backend=eager"])
    assert backend_ns.compilation_config == {"backend": "eager"}

    # The deprecation warning must have been logged.
    assert len(caplog_vllm.records) > 0
    assert deprecation_notice in caplog_vllm.text

    # -O.mode must be converted correctly as well.
    # Note: warning_once won't emit again in same session, so only the
    # parsed value is checked here.
    mode_ns = parser.parse_args(["-O.mode=2"])
    assert mode_ns.compilation_config == {"mode": 2}
tests/v1/attention/test_attention_splitting.py
View file @
8d75f22e
...
...
@@ -13,7 +13,7 @@ from vllm.v1.attention.backends.utils import (
split_attn_metadata
,
split_decodes_and_prefills
,
)
from
vllm.v1.worker.ubatch_utils
import
create_ubatch_slices
from
vllm.v1.worker.ubatch_utils
import
maybe_
create_ubatch_slices
@
pytest
.
fixture
...
...
@@ -154,7 +154,10 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
def
apply_split_decodes_and_prefills
(
query_lens
:
list
[
int
],
decode_threshold
:
int
,
require_uniform
:
bool
query_lens
:
list
[
int
],
decode_threshold
:
int
,
require_uniform
:
bool
,
padded_num_tokens
:
int
|
None
=
None
,
):
"""Helper function to apply split_decodes_and_prefills and return
the results."""
...
...
@@ -165,6 +168,10 @@ def apply_split_decodes_and_prefills(
block_size
=
16
,
device
=
device
,
)
if
padded_num_tokens
is
not
None
:
common_metadata
.
num_actual_tokens
=
padded_num_tokens
return
split_decodes_and_prefills
(
common_metadata
,
decode_threshold
=
decode_threshold
,
...
...
@@ -271,6 +278,22 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
assert
num_prefill_tokens
==
(
sum
(
query_lens
)
-
2
)
# rest of the tokens
def test_split_decodes_and_prefills_uniform_padded_batch_all_same():
    """Uniform batch of equal-length queries followed by a zero-length padded request."""
    # Three requests of length 2 (within decode_threshold=3) plus one
    # zero-length padding request; the token count is overridden to the
    # padded total of 8.
    query_lens = [2] * 3 + [0]
    padded_num_tokens = 8

    split_counts = apply_split_decodes_and_prefills(
        query_lens,
        decode_threshold=3,
        require_uniform=True,
        padded_num_tokens=padded_num_tokens,
    )
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split_counts

    # Every request, the padded one included, must be counted as a decode,
    # and all padded tokens must be attributed to decodes.
    assert num_decodes == 4
    assert num_prefills == 0
    assert num_decode_tokens == padded_num_tokens
    assert num_prefill_tokens == 0
@
pytest
.
mark
.
parametrize
(
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs"
,
[
...
...
@@ -294,8 +317,14 @@ def test_prefill_split_across_ubatches(
qsl_np
=
common
.
query_start_loc_cpu
.
numpy
()
num_tokens
=
common
.
num_actual_tokens
ubatch_slices
=
create_ubatch_slices
(
num_scheduled_tokens
,
split_point
)
assert
len
(
ubatch_slices
)
==
2
ubatch_slices
,
_
=
maybe_create_ubatch_slices
(
True
,
num_scheduled_tokens
,
num_tokens
,
batch_spec
.
batch_size
,
split_point
=
split_point
,
)
assert
ubatch_slices
is
not
None
and
len
(
ubatch_slices
)
==
2
first_meta
=
_make_metadata_with_slice
(
ubatch_slices
[
0
],
common
)
second_meta
=
_make_metadata_with_slice
(
ubatch_slices
[
1
],
common
)
...
...
tests/v1/attention/utils.py
View file @
8d75f22e
...
...
@@ -106,8 +106,8 @@ def create_common_attn_metadata(
query_start_loc
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens
=
seq_lens
,
seq_lens_cpu
=
seq_lens_cpu
,
num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
_
seq_lens_cpu
=
seq_lens_cpu
,
_
num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
num_reqs
=
batch_spec
.
batch_size
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
...
...
tests/v1/core/test_reset_prefix_cache_e2e.py
View file @
8d75f22e
...
...
@@ -11,7 +11,9 @@ PROMPTS = [
]
def
test_reset_prefix_cache_e2e
():
def
test_reset_prefix_cache_e2e
(
monkeypatch
):
# "spawn" is required for test to be deterministic
monkeypatch
.
setenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
)
engine_args
=
EngineArgs
(
model
=
"Qwen/Qwen3-0.6B"
,
gpu_memory_utilization
=
0.2
,
...
...
@@ -19,6 +21,7 @@ def test_reset_prefix_cache_e2e():
max_num_batched_tokens
=
32
,
max_model_len
=
2048
,
compilation_config
=
{
"mode"
:
0
},
dtype
=
"float16"
,
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
sampling_params
=
SamplingParams
(
...
...
tests/v1/core/test_scheduler.py
View file @
8d75f22e
...
...
@@ -1536,7 +1536,7 @@ def create_scheduler_with_priority(
)
kv_transfer_config
=
(
KVTransferConfig
(
kv_connector
=
"
SharedStorag
eConnector"
,
kv_connector
=
"
Exampl
eConnector"
,
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
"shared_storage_path"
:
"local_storage"
},
)
...
...
@@ -1552,7 +1552,7 @@ def create_scheduler_with_priority(
ec_transfer_config
=
(
ECTransferConfig
(
ec_connector
=
"EC
SharedStorag
eConnector"
,
ec_connector
=
"EC
Exampl
eConnector"
,
ec_role
=
ec_role
,
ec_connector_extra_config
=
{
"shared_storage_path"
:
"/tmp/ec_test"
},
)
...
...
@@ -2413,7 +2413,7 @@ def _assert_right_ec_connector_metadata(
metadata_dict
=
{
mm_data
.
mm_hash
:
mm_data
for
mm_data
in
metadata
.
mm_datas
}
# Check all required identifiers exist in metadata; and no extra
# In EC
SharedStorag
eConnector format
# In EC
Exampl
eConnector format
# NOTE: even having same identifier, the mm_features can be different
# since their mm_position can be in different offsets, etc
identifiers_dict
=
{
f
.
identifier
for
f
in
mm_features_list
}
...
...
tests/v1/core/utils.py
View file @
8d75f22e
...
...
@@ -108,7 +108,7 @@ def create_scheduler(
)
elif
use_kv_connector
:
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"
SharedStorag
eConnector"
,
kv_connector
=
"
Exampl
eConnector"
,
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
"shared_storage_path"
:
"local_storage"
},
)
...
...
@@ -121,7 +121,7 @@ def create_scheduler(
ec_transfer_config
=
(
ECTransferConfig
(
ec_connector
=
"EC
SharedStorag
eConnector"
,
ec_connector
=
"EC
Exampl
eConnector"
,
ec_role
=
ec_role
,
ec_connector_extra_config
=
{
"shared_storage_path"
:
"/tmp/ec_test"
},
)
...
...
tests/v1/cudagraph/test_cudagraph_dispatch.py
View file @
8d75f22e
...
...
@@ -161,10 +161,10 @@ class TestCudagraphDispatcher:
assert
rt_mode
==
CUDAGraphMode
.
NONE
assert
key
==
BatchDescriptor
(
num_tokens
=
15
)
# 4.
Cascade attention
should have a fall back mode
# 4.
disable_full
should have a fall back mode
(e.g., cascade attention)
desc_full_exact
=
BatchDescriptor
(
num_tokens
=
8
,
uniform
=
False
)
rt_mode
,
key
=
dispatcher
.
dispatch
(
num_tokens
=
8
,
uniform_decode
=
False
,
has_lora
=
False
,
use_cascade_attn
=
True
num_tokens
=
8
,
uniform_decode
=
False
,
has_lora
=
False
,
disable_full
=
True
)
if
"PIECEWISE"
in
cudagraph_mode_str
:
# string contains check
assert
rt_mode
==
CUDAGraphMode
.
PIECEWISE
...
...
tests/v1/cudagraph/test_cudagraph_mode.py
View file @
8d75f22e
...
...
@@ -100,32 +100,20 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
# test cudagraph_mode with different compilation mode.
# (backend_name, cudagraph_mode, compilation_mode, supported)
if
current_platform
.
is_rocm
():
combo_cases_2
=
[
(
"RocmAttn"
,
"FULL"
,
CompilationMode
.
NONE
,
True
),
(
"RocmAttn"
,
"FULL"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
(
"RocmAttn"
,
"PIECEWISE"
,
CompilationMode
.
NONE
,
False
),
(
"RocmAttn"
,
"PIECEWISE"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
(
"RocmAttn"
,
"FULL_AND_PIECEWISE"
,
CompilationMode
.
NONE
,
False
),
(
"RocmAttn"
,
"FULL_AND_PIECEWISE"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
(
"RocmAttn"
,
"FULL_DECODE_ONLY"
,
CompilationMode
.
NONE
,
True
),
(
"RocmAttn"
,
"FULL_DECODE_ONLY"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
(
"RocmAttn"
,
"NONE"
,
CompilationMode
.
NONE
,
True
),
(
"RocmAttn"
,
"NONE"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
]
else
:
combo_cases_2
=
[
(
"FA2"
,
"FULL"
,
CompilationMode
.
NONE
,
True
),
(
"FA2"
,
"FULL"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
(
"FA2"
,
"PIECEWISE"
,
CompilationMode
.
NONE
,
True
),
(
"FA2"
,
"PIECEWISE"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
(
"FA2"
,
"FULL_AND_PIECEWISE"
,
CompilationMode
.
NONE
,
True
),
(
"FA2"
,
"FULL_AND_PIECEWISE"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
(
"FA2"
,
"FULL_DECODE_ONLY"
,
CompilationMode
.
NONE
,
True
),
(
"FA2"
,
"FULL_DECODE_ONLY"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
(
"FA2"
,
"NONE"
,
CompilationMode
.
NONE
,
True
),
(
"FA2"
,
"NONE"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
]
attn_backend
=
"RocmAttn"
if
current_platform
.
is_rocm
()
else
"FA2"
combo_cases_2
=
[
(
attn_backend
,
"FULL"
,
CompilationMode
.
NONE
,
True
),
(
attn_backend
,
"FULL"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
(
attn_backend
,
"PIECEWISE"
,
CompilationMode
.
NONE
,
True
),
(
attn_backend
,
"PIECEWISE"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
(
attn_backend
,
"FULL_AND_PIECEWISE"
,
CompilationMode
.
NONE
,
True
),
(
attn_backend
,
"FULL_AND_PIECEWISE"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
(
attn_backend
,
"FULL_DECODE_ONLY"
,
CompilationMode
.
NONE
,
True
),
(
attn_backend
,
"FULL_DECODE_ONLY"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
(
attn_backend
,
"NONE"
,
CompilationMode
.
NONE
,
True
),
(
attn_backend
,
"NONE"
,
CompilationMode
.
VLLM_COMPILE
,
True
),
]
@
pytest
.
mark
.
parametrize
(
...
...
tests/v1/determinism/test_batch_invariance.py
View file @
8d75f22e
...
...
@@ -10,6 +10,7 @@ from utils import (
BACKENDS
,
_extract_step_logprobs
,
_random_prompt
,
is_device_capability_below_90
,
resolve_model_name
,
skip_unsupported
,
)
...
...
@@ -17,6 +18,8 @@ from utils import (
import
vllm.model_executor.layers.batch_invariant
as
batch_invariant
from
vllm
import
LLM
,
SamplingParams
IS_DEVICE_CAPABILITY_BELOW_90
=
is_device_capability_below_90
()
@
skip_unsupported
@
pytest
.
mark
.
timeout
(
1000
)
...
...
@@ -185,11 +188,12 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm
=
LLM
(
model
=
model_name
,
tensor_parallel_size
=
tp_size
,
enable_prefix_caching
=
False
,
#
enable_prefix_caching=False,
max_num_seqs
=
32
,
max_model_len
=
8192
,
dtype
=
"bfloat16"
,
# not everything is supported
gpu_memory_utilization
=
0.9
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
)
# Use more realistic prompts for better token generation
...
...
@@ -394,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
max_model_len
=
2048
,
dtype
=
"bfloat16"
,
enable_prefix_caching
=
False
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
)
prompt
=
"the capital of france is"
...
...
@@ -457,10 +462,10 @@ def test_logprobs_without_batch_invariance_should_fail(
llm
=
LLM
(
model
=
model_name
,
tensor_parallel_size
=
tp_size
,
enable_prefix_caching
=
False
,
max_num_seqs
=
32
,
max_model_len
=
8192
,
dtype
=
"bfloat16"
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
)
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
...
...
@@ -681,10 +686,10 @@ def test_decode_logprobs_match_prefill_logprobs(
llm
=
LLM
(
model
=
model_name
,
tensor_parallel_size
=
tp_size
,
enable_prefix_caching
=
False
,
max_num_seqs
=
32
,
max_model_len
=
8192
,
dtype
=
"bfloat16"
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
)
# Use a few test prompts
...
...
@@ -929,6 +934,7 @@ def LLM_with_max_seqs(
dtype
=
"bfloat16"
,
tensor_parallel_size
=
int
(
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)),
enable_prefix_caching
=
False
,
enforce_eager
=
IS_DEVICE_CAPABILITY_BELOW_90
,
# Enable for MOE models
# enable_expert_parallel=True,
)
tests/v1/determinism/test_online_batch_invariance.py
View file @
8d75f22e
...
...
@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
}
tp_size
=
os
.
getenv
(
"VLLM_TP_SIZE"
,
"1"
)
server_args
:
list
[
str
]
=
[]
server_args
:
list
[
str
]
=
[
"--max-model-len=8192"
,
"--max-num-seqs=32"
,
]
if
tp_size
:
server_args
+=
[
"-tp"
,
tp_size
]
...
...
Prev
1
…
11
12
13
14
15
16
17
18
19
…
33
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment