Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7e63ef82
Commit
7e63ef82
authored
Jan 21, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.14.0' into v0.14.0-dev
parents
8cbcac5d
b17039bc
Changes
681
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1155 additions
and
333 deletions
+1155
-333
tests/rocm/aiter/test_mla_fp8_support_check.py
tests/rocm/aiter/test_mla_fp8_support_check.py
+118
-0
tests/standalone_tests/lazy_imports.py
tests/standalone_tests/lazy_imports.py
+6
-25
tests/standalone_tests/python_only_compile.sh
tests/standalone_tests/python_only_compile.sh
+20
-8
tests/standalone_tests/pytorch_nightly_dependency.sh
tests/standalone_tests/pytorch_nightly_dependency.sh
+6
-1
tests/test_attention_backend_registry.py
tests/test_attention_backend_registry.py
+169
-0
tests/test_config.py
tests/test_config.py
+99
-44
tests/test_pooling_params.py
tests/test_pooling_params.py
+15
-15
tests/test_routing_simulator.py
tests/test_routing_simulator.py
+1
-1
tests/tokenizers_/test_basic.py
tests/tokenizers_/test_basic.py
+5
-0
tests/tokenizers_/test_detokenize.py
tests/tokenizers_/test_detokenize.py
+2
-1
tests/tool_parsers/test_functiongemma_tool_parser.py
tests/tool_parsers/test_functiongemma_tool_parser.py
+154
-0
tests/tool_parsers/test_kimi_k2_tool_parser.py
tests/tool_parsers/test_kimi_k2_tool_parser.py
+192
-191
tests/tool_parsers/test_mistral_tool_parser.py
tests/tool_parsers/test_mistral_tool_parser.py
+34
-2
tests/tool_use/test_chat_completions.py
tests/tool_use/test_chat_completions.py
+42
-0
tests/tool_use/test_minimax_m2_tool_parser.py
tests/tool_use/test_minimax_m2_tool_parser.py
+119
-0
tests/tool_use/test_tool_choice_required.py
tests/tool_use/test_tool_choice_required.py
+33
-0
tests/utils.py
tests/utils.py
+28
-6
tests/utils_/test_torch_utils.py
tests/utils_/test_torch_utils.py
+9
-21
tests/v1/attention/test_attention_backends.py
tests/v1/attention/test_attention_backends.py
+97
-17
tests/v1/attention/test_attention_backends_selection.py
tests/v1/attention/test_attention_backends_selection.py
+6
-1
No files found.
Too many changes to show.
To preserve performance only
681 of 681+
files are displayed.
Plain diff
Email patch
tests/rocm/aiter/test_mla_fp8_support_check.py
0 → 100644
View file @
7e63ef82
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for AITER MLA FP8 support detection.
These tests verify that the _check_aiter_mla_fp8_support() function
correctly handles various error conditions without crashing.
"""
from
unittest.mock
import
patch
import
pytest
class TestAiterMlaFp8SupportCheck:
    """Exercise _check_aiter_mla_fp8_support() error handling and caching."""

    def setup_method(self):
        """Clear the module-level FP8 support cache ahead of every test."""
        import vllm._aiter_ops as ops_module

        ops_module._AITER_MLA_SUPPORTS_FP8 = None

    @staticmethod
    def _check_with_failing_signature(error):
        """Run the check while ``inspect.signature`` raises ``error``.

        Returns a ``(result, cached_value)`` pair: what the check returned
        and what the module-level cache holds immediately afterwards.
        """
        import vllm._aiter_ops as ops_module

        ops_module._AITER_MLA_SUPPORTS_FP8 = None
        with patch("vllm._aiter_ops.inspect.signature", side_effect=error):
            outcome = ops_module._check_aiter_mla_fp8_support()
            cached = ops_module._AITER_MLA_SUPPORTS_FP8
        return outcome, cached

    @patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
    def test_import_error_handling(self, mock_supported):
        """An ImportError from the signature lookup must yield False."""
        outcome, _ = self._check_with_failing_signature(ImportError("No module"))
        # The check is expected to report False rather than propagate.
        assert outcome is False

    @patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
    def test_module_not_found_error_handling(self, mock_supported):
        """A ModuleNotFoundError must yield False and cache False."""
        outcome, cached = self._check_with_failing_signature(
            ModuleNotFoundError("Module not found")
        )
        assert outcome is False
        assert cached is False

    @patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
    def test_attribute_error_handling(self, mock_supported):
        """An AttributeError must yield False and cache False."""
        outcome, cached = self._check_with_failing_signature(
            AttributeError("No attribute")
        )
        assert outcome is False
        assert cached is False

    @patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
    def test_value_error_handling(self, mock_supported):
        """A ValueError (no signature available) must yield False and cache False."""
        outcome, cached = self._check_with_failing_signature(
            ValueError("No signature")
        )
        assert outcome is False
        assert cached is False

    @patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
    def test_type_error_handling(self, mock_supported):
        """A TypeError (object not callable) must yield False and cache False."""
        outcome, cached = self._check_with_failing_signature(
            TypeError("Not a callable")
        )
        assert outcome is False
        assert cached is False

    @patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
    def test_result_caching(self, mock_supported):
        """A pre-populated cache value is what the check returns."""
        import vllm._aiter_ops as ops_module

        # Seed the cache so the check has a stored answer to hand back.
        ops_module._AITER_MLA_SUPPORTS_FP8 = True

        assert ops_module._check_aiter_mla_fp8_support() is True
# Allow running this test module directly as a script (verbose output).
if __name__ == "__main__":
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
tests/standalone_tests/lazy_imports.py
View file @
7e63ef82
...
@@ -5,9 +5,6 @@
...
@@ -5,9 +5,6 @@
# The utility function cannot be placed in `vllm.utils`
# The utility function cannot be placed in `vllm.utils`
# this needs to be a standalone script
# this needs to be a standalone script
import
sys
import
sys
from
contextlib
import
nullcontext
from
vllm_test_utils
import
BlameResult
,
blame
# List of modules that should not be imported too early.
# List of modules that should not be imported too early.
# Lazy import `torch._inductor.async_compile` to avoid creating
# Lazy import `torch._inductor.async_compile` to avoid creating
...
@@ -16,26 +13,10 @@ from vllm_test_utils import BlameResult, blame
...
@@ -16,26 +13,10 @@ from vllm_test_utils import BlameResult, blame
# `cv2` can easily mess up the environment.
# `cv2` can easily mess up the environment.
module_names
=
[
"torch._inductor.async_compile"
,
"cv2"
]
module_names
=
[
"torch._inductor.async_compile"
,
"cv2"
]
# set all modules in `module_names` to be None.
# if we import any modules during `import vllm`, there would be a
# hard error and nice stacktrace on the first import.
for
module_name
in
module_names
:
sys
.
modules
[
module_name
]
=
None
# type: ignore[assignment]
def
any_module_imported
():
import
vllm
# noqa
return
any
(
module_name
in
sys
.
modules
for
module_name
in
module_names
)
# In CI, we only check finally if the module is imported.
# If it is indeed imported, we can rerun the test with `use_blame=True`,
# which will trace every function call to find the first import location,
# and help find the root cause.
# We don't run it in CI by default because it is slow.
use_blame
=
False
context
=
blame
(
any_module_imported
)
if
use_blame
else
nullcontext
()
with
context
as
result
:
import
vllm
# noqa
if
use_blame
:
assert
isinstance
(
result
,
BlameResult
)
print
(
f
"the first import location is:
\n
{
result
.
trace_stack
}
"
)
assert
not
any_module_imported
(),
(
f
"Some the modules in
{
module_names
}
are imported. To see the first"
f
" import location, run the test with `use_blame=True`."
)
tests/standalone_tests/python_only_compile.sh
View file @
7e63ef82
...
@@ -18,25 +18,37 @@ for i in {1..5}; do
...
@@ -18,25 +18,37 @@ for i in {1..5}; do
echo
"Checking metadata.json URL (attempt
$i
)..."
echo
"Checking metadata.json URL (attempt
$i
)..."
if
curl
--fail
"
$meta_json_url
"
>
metadata.json
;
then
if
curl
--fail
"
$meta_json_url
"
>
metadata.json
;
then
echo
"INFO: metadata.json URL is valid."
echo
"INFO: metadata.json URL is valid."
# check whether it is valid json by python
# check whether it is valid json by python
(printed to stdout)
if
python3
-m
json.tool metadata.json
;
then
if
python3
-m
json.tool metadata.json
;
then
echo
"INFO: metadata.json is valid JSON. Proceeding with the test."
echo
"INFO: metadata.json is valid JSON. Proceeding with the check."
# check whether there is an object in the json matching:
# "package_name": "vllm", and "platform_tag" matches the current architecture
# see `determine_wheel_url` in setup.py for more details
if
python3
-c
"import platform as p,json as j,sys as s; d = j.load(open('metadata.json'));
\
s.exit(int(not any(o.get('package_name') == 'vllm' and p.machine() in o.get('platform_tag')
\
for o in d)))"
2>/dev/null
;
then
echo
"INFO: metadata.json contains a pre-compiled wheel for the current architecture."
break
else
echo
"WARN: metadata.json does not have a pre-compiled wheel for the current architecture."
fi
else
else
echo
"CRITICAL: metadata.json exists but is not valid JSON, please do report in #sig-ci channel!"
echo
"CRITICAL: metadata.json exists but is not valid JSON, please do report in #sig-ci channel!"
echo
"INFO: metadata.json content:"
cat
metadata.json
exit
1
exit
1
fi
fi
break
fi
fi
# failure handling
# failure handling
& retry logic
if
[
$i
-eq
5
]
;
then
if
[
$i
-eq
5
]
;
then
echo
"ERROR: metadata
.json URL
is still not va
lid
after 5 attempts."
echo
"ERROR: metadata is still not
a
va
ilable
after 5 attempts."
echo
"ERROR: Please check whether the precompiled wheel for commit
$merge_base_commit
exists
."
echo
"ERROR: Please check whether the precompiled wheel for commit
$merge_base_commit
is available
."
echo
" NOTE: If
$merge_base_commit
is a new commit on main, maybe try again after its release pipeline finishes."
echo
" NOTE: If
$merge_base_commit
is a new commit on main, maybe try again after its release pipeline finishes."
echo
" NOTE: If it fails, please report in #sig-ci channel."
echo
" NOTE: If it fails, please report in #sig-ci channel."
exit
1
exit
1
else
else
echo
"WARNING: metadata
.json URL
is not va
lid
. Retrying
in 3
minutes..."
echo
"WARNING: metadata is not
a
va
ilable
. Retrying
after 5
minutes..."
sleep
18
0
sleep
30
0
fi
fi
done
done
...
...
tests/standalone_tests/pytorch_nightly_dependency.sh
View file @
7e63ef82
...
@@ -4,6 +4,11 @@
...
@@ -4,6 +4,11 @@
set
-e
set
-e
set
-x
set
-x
if
command
-v
rocminfo
>
/dev/null 2>&1
;
then
echo
"Skipping test for ROCm platform"
exit
0
fi
cd
/vllm-workspace/
cd
/vllm-workspace/
rm
-rf
.venv
rm
-rf
.venv
...
@@ -36,7 +41,7 @@ if diff before.txt after.txt; then
...
@@ -36,7 +41,7 @@ if diff before.txt after.txt; then
echo
"torch version not overridden."
echo
"torch version not overridden."
else
else
echo
"torch version overridden by nightly_torch_test.txt,
\
echo
"torch version overridden by nightly_torch_test.txt,
\
if the dependency is not triggered by the pyt
r
och nightly test,
\
if the dependency is not triggered by the pyto
r
ch nightly test,
\
please add the dependency to the list 'white_list' in tools/pre_commit/generate_nightly_torch_test.py"
please add the dependency to the list 'white_list' in tools/pre_commit/generate_nightly_torch_test.py"
exit
1
exit
1
fi
fi
tests/test_attention_backend_registry.py
0 → 100644
View file @
7e63ef82
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionImpl
,
)
from
vllm.v1.attention.backends.registry
import
(
AttentionBackendEnum
,
MambaAttentionBackendEnum
,
register_backend
,
)
class CustomAttentionImpl(AttentionImpl):
    """Stand-in attention implementation used by the registry tests."""

    def __init__(self, *_args, **_kwargs):
        # Constructor arguments are accepted but deliberately not forwarded.
        super().__init__()

    def forward(self, *_args, **_kwargs):
        """No-op forward pass; always returns None."""
        return None
class CustomAttentionBackend(AttentionBackend):
    """Stand-in attention backend used by the registry tests."""

    @staticmethod
    def get_name():
        """Return the backend name reported to the registry tests."""
        name = "CUSTOM"
        return name

    @staticmethod
    def get_impl_cls():
        """Return the mock implementation class paired with this backend."""
        return CustomAttentionImpl

    @staticmethod
    def get_builder_cls():
        """No builder is provided by this mock; returns None."""
        return None

    @staticmethod
    def get_required_kv_cache_layout():
        """No KV cache layout is required by this mock; returns None."""
        return None
class CustomMambaAttentionImpl(AttentionImpl):
    """Stand-in mamba attention implementation used by the registry tests."""

    def __init__(self, *_args, **_kwargs):
        # Constructor arguments are accepted but deliberately not forwarded.
        super().__init__()

    def forward(self, *_args, **_kwargs):
        """No-op forward pass; always returns None."""
        return None
class CustomMambaAttentionBackend(AttentionBackend):
    """Stand-in mamba attention backend used by the registry tests."""

    @staticmethod
    def get_name():
        """Return the backend name reported to the registry tests."""
        name = "CUSTOM_MAMBA"
        return name

    @staticmethod
    def get_impl_cls():
        """Return the mock implementation class paired with this backend."""
        return CustomMambaAttentionImpl

    @staticmethod
    def get_builder_cls():
        """No builder is provided by this mock; returns None."""
        return None

    @staticmethod
    def get_required_kv_cache_layout():
        """No KV cache layout is required by this mock; returns None."""
        return None
def test_custom_is_not_alias_of_any_backend():
    """CUSTOM must be a distinct member of AttentionBackendEnum.

    Enum members that share a value collapse into one object, with the
    later names becoming aliases of the first. Iterating the enum class
    only yields canonical members, so aliases are looked up through
    ``__members__`` instead, which maps every declared name (aliases
    included) to its member.
    """
    all_members = AttentionBackendEnum.__members__
    custom = AttentionBackendEnum.CUSTOM

    # Any name other than "CUSTOM" that resolves to the very same member
    # object means CUSTOM shares its identity with another backend.
    aliases = [
        name
        for name, member in all_members.items()
        if name != "CUSTOM" and member is custom
    ]

    # CUSTOM should not be an alias of any other backend
    assert not aliases, (
        f"BUG! CUSTOM is an alias of: {', '.join(aliases)}!\n"
        f"CUSTOM.value = {repr(custom.value)}\n"
        f"This happens when CUSTOM has the same value as another backend.\n"
        f"When you register to CUSTOM, you're actually registering to {aliases[0]}!\n"
        f"All backend values:\n"
        + "\n".join(
            f"{name}: {repr(member.value)}" for name, member in all_members.items()
        )
    )

    # Verify CUSTOM has its own unique identity
    assert custom.name == "CUSTOM", (
        f"CUSTOM.name should be 'CUSTOM', but got '{custom.name}'"
    )
def test_register_custom_backend_with_class_path():
    """Register a class path on CUSTOM and check it resolves to the mock backend."""
    dotted_path = "tests.test_attention_backend_registry.CustomAttentionBackend"
    custom = AttentionBackendEnum.CUSTOM

    # Register with an explicit class path.
    register_backend(backend=custom, class_path=dotted_path, is_mamba=False)

    # The member should now report itself as overridden ...
    assert custom.is_overridden(), "CUSTOM should be overridden after registration"
    # ... and hand back exactly the path that was registered.
    assert custom.get_path() == dotted_path

    # Resolving the class should give the mock backend defined in this module.
    resolved_cls = custom.get_class()
    assert resolved_cls.get_name() == "CUSTOM"
    assert resolved_cls.get_impl_cls() == CustomAttentionImpl
def test_mamba_custom_is_not_alias_of_any_backend():
    """CUSTOM must be a distinct member of MambaAttentionBackendEnum.

    Enum members that share a value collapse into one object, with the
    later names becoming aliases of the first. Iterating the enum class
    only yields canonical members, so aliases are looked up through
    ``__members__`` instead, which maps every declared name (aliases
    included) to its member.
    """
    all_members = MambaAttentionBackendEnum.__members__
    custom = MambaAttentionBackendEnum.CUSTOM

    # Any name other than "CUSTOM" that resolves to the very same member
    # object means CUSTOM shares its identity with another backend.
    aliases = [
        name
        for name, member in all_members.items()
        if name != "CUSTOM" and member is custom
    ]

    # CUSTOM should not be an alias of any other backend
    assert not aliases, (
        f"BUG! MambaAttentionBackendEnum.CUSTOM is an alias of: {', '.join(aliases)}!\n"
        f"CUSTOM.value = {repr(custom.value)}\n"
        f"All mamba backend values:\n"
        + "\n".join(
            f"{name}: {repr(member.value)}" for name, member in all_members.items()
        )
    )
def test_register_custom_mamba_backend_with_class_path():
    """Register a class path on the mamba CUSTOM member and check it resolves."""
    dotted_path = "tests.test_attention_backend_registry.CustomMambaAttentionBackend"
    custom = MambaAttentionBackendEnum.CUSTOM

    # Register with an explicit class path.
    register_backend(backend=custom, class_path=dotted_path, is_mamba=True)

    # The member should now report itself as overridden ...
    assert custom.is_overridden()
    # ... and hand back exactly the path that was registered.
    assert custom.get_path() == dotted_path

    # Resolving the class should give the mock mamba backend defined in this module.
    resolved_cls = custom.get_class()
    assert resolved_cls.get_name() == "CUSTOM_MAMBA"
    assert resolved_cls.get_impl_cls() == CustomMambaAttentionImpl
tests/test_config.py
View file @
7e63ef82
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
logging
import
logging
import
os
import
os
from
dataclasses
import
MISSING
,
Field
,
asdict
,
dataclass
,
field
from
dataclasses
import
MISSING
,
Field
,
asdict
,
dataclass
,
field
...
@@ -25,7 +26,6 @@ from vllm.config.vllm import (
...
@@ -25,7 +26,6 @@ from vllm.config.vllm import (
OPTIMIZATION_LEVEL_TO_CONFIG
,
OPTIMIZATION_LEVEL_TO_CONFIG
,
OptimizationLevel
,
OptimizationLevel
,
)
)
from
vllm.model_executor.layers.pooler
import
PoolingType
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
utils
import
models_path_prefix
from
utils
import
models_path_prefix
...
@@ -162,8 +162,9 @@ def test_get_pooling_config():
...
@@ -162,8 +162,9 @@ def test_get_pooling_config():
model_config
=
ModelConfig
(
model_id
)
model_config
=
ModelConfig
(
model_id
)
assert
model_config
.
pooler_config
is
not
None
assert
model_config
.
pooler_config
is
not
None
assert
model_config
.
pooler_config
.
normalize
assert
model_config
.
pooler_config
.
use_activation
assert
model_config
.
pooler_config
.
pooling_type
==
PoolingType
.
MEAN
.
name
assert
model_config
.
pooler_config
.
seq_pooling_type
==
"MEAN"
assert
model_config
.
pooler_config
.
tok_pooling_type
==
"ALL"
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
...
@@ -171,7 +172,7 @@ def test_get_pooling_config():
...
@@ -171,7 +172,7 @@ def test_get_pooling_config():
)
)
def
test_get_pooling_config_from_args
():
def
test_get_pooling_config_from_args
():
model_id
=
os
.
path
.
join
(
models_path_prefix
,
"sentence-transformers/all-MiniLM-L12-v2"
)
model_id
=
os
.
path
.
join
(
models_path_prefix
,
"sentence-transformers/all-MiniLM-L12-v2"
)
pooler_config
=
PoolerConfig
(
pooling_type
=
"CLS"
,
normalize
=
True
)
pooler_config
=
PoolerConfig
(
seq_
pooling_type
=
"CLS"
,
normalize
=
True
)
model_config
=
ModelConfig
(
model_id
,
pooler_config
=
pooler_config
)
model_config
=
ModelConfig
(
model_id
,
pooler_config
=
pooler_config
)
assert
asdict
(
model_config
.
pooler_config
)
==
asdict
(
pooler_config
)
assert
asdict
(
model_config
.
pooler_config
)
==
asdict
(
pooler_config
)
...
@@ -182,14 +183,25 @@ def test_get_pooling_config_from_args():
...
@@ -182,14 +183,25 @@ def test_get_pooling_config_from_args():
[
[
(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
,
"LAST"
,
"LAST"
),
# LLM
(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
,
"LAST"
,
"LAST"
),
# LLM
(
"intfloat/e5-small"
,
"CLS"
,
"MEAN"
),
# BertModel
(
"intfloat/e5-small"
,
"CLS"
,
"MEAN"
),
# BertModel
],
)
def
test_default_seq_pooling_type
(
model_id
,
default_pooling_type
,
pooling_type
):
model_config
=
ModelConfig
(
model_id
)
assert
model_config
.
_model_info
.
default_seq_pooling_type
==
default_pooling_type
assert
model_config
.
pooler_config
.
seq_pooling_type
==
pooling_type
@
pytest
.
mark
.
parametrize
(
(
"model_id"
,
"default_pooling_type"
,
"pooling_type"
),
[
(
"Qwen/Qwen2.5-Math-RM-72B"
,
"ALL"
,
"ALL"
),
# reward
(
"Qwen/Qwen2.5-Math-RM-72B"
,
"ALL"
,
"ALL"
),
# reward
(
"Qwen/Qwen2.5-Math-PRM-7B"
,
"STEP"
,
"STEP"
),
# step reward
(
"Qwen/Qwen2.5-Math-PRM-7B"
,
"STEP"
,
"STEP"
),
# step reward
],
],
)
)
def
test_default_pooling_type
(
model_id
,
default_pooling_type
,
pooling_type
):
def
test_default_
tok_
pooling_type
(
model_id
,
default_pooling_type
,
pooling_type
):
model_config
=
ModelConfig
(
model_id
)
model_config
=
ModelConfig
(
model_id
)
assert
model_config
.
_model_info
.
default_pooling_type
==
default_pooling_type
assert
model_config
.
_model_info
.
default_
tok_
pooling_type
==
default_pooling_type
assert
model_config
.
pooler_config
.
pooling_type
==
pooling_type
assert
model_config
.
pooler_config
.
tok_
pooling_type
==
pooling_type
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -207,8 +219,8 @@ def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
...
@@ -207,8 +219,8 @@ def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
)
)
def
test_moe_model_detection
(
model_id
,
expected_is_moe_model
):
def
test_moe_model_detection
(
model_id
,
expected_is_moe_model
):
model_config
=
ModelConfig
(
model_id
)
model_config
=
ModelConfig
(
model_id
)
# Just check that is_moe
_model
field exists and is a boolean
# Just check that is_moe field exists and is a boolean
assert
model_config
.
is_mo
del_moe
()
==
expected_is_moe_model
assert
model_config
.
is_mo
e
==
expected_is_moe_model
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -226,7 +238,7 @@ def test_moe_model_detection(model_id, expected_is_moe_model):
...
@@ -226,7 +238,7 @@ def test_moe_model_detection(model_id, expected_is_moe_model):
def
test_is_quantized
(
model_id
,
quantized
):
def
test_is_quantized
(
model_id
,
quantized
):
model_config
=
ModelConfig
(
model_id
)
model_config
=
ModelConfig
(
model_id
)
# Just check that quantized field exists and is a boolean
# Just check that quantized field exists and is a boolean
assert
model_config
.
is_quantized
()
==
quantized
assert
model_config
.
is_quantized
==
quantized
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
...
@@ -556,100 +568,100 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
...
@@ -556,100 +568,100 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
"jason9693/Qwen2.5-1.5B-apeach"
,
"jason9693/Qwen2.5-1.5B-apeach"
,
"decoder"
,
"decoder"
,
True
,
True
,
"Pooling models with causal attn and
last
pooling support chunked prefill."
,
"Pooling models with causal attn and
LAST/ALL
pooling support chunked prefill."
,
# noqa: E501
),
),
(
(
"Qwen/Qwen3-Embedding-0.6B"
,
"Qwen/Qwen3-Embedding-0.6B"
,
"decoder"
,
"decoder"
,
True
,
True
,
"Pooling models with causal attn and
last
pooling support chunked prefill."
,
"Pooling models with causal attn and
LAST/ALL
pooling support chunked prefill."
,
# noqa: E501
),
),
(
(
"Qwen/Qwen2.5-Math-PRM-7B"
,
"Qwen/Qwen2.5-Math-PRM-7B"
,
"decoder"
,
"decoder"
,
False
,
False
,
"Pooling models with
step
pooling do
es
not support chunked prefill."
,
"Pooling models with
causal attn and LAST/STEP
pooling do not support chunked prefill."
,
# noqa: E501
),
),
(
(
"internlm/internlm2-1_8b-reward"
,
"internlm/internlm2-1_8b-reward"
,
"decoder"
,
"decoder"
,
True
,
True
,
"Pooling models with causal attn and
all
pooling support chunked prefill."
,
"Pooling models with causal attn and
LAST/ALL
pooling support chunked prefill."
,
# noqa: E501
),
),
(
(
"BAAI/bge-base-en"
,
"BAAI/bge-base-en"
,
"encoder_only"
,
"encoder_only"
,
False
,
False
,
"Pooling models with bidirectional attn do
es
not support chunked prefill."
,
"Pooling models with bidirectional attn do not support chunked prefill."
,
# noqa: E501
),
),
(
(
"boltuix/NeuroBERT-NER"
,
"boltuix/NeuroBERT-NER"
,
"encoder_only"
,
"encoder_only"
,
False
,
False
,
"Pooling models with bidirectional attn do
es
not support chunked prefill."
,
"Pooling models with bidirectional attn do not support chunked prefill."
,
# noqa: E501
),
),
(
(
"papluca/xlm-roberta-base-language-detection"
,
"papluca/xlm-roberta-base-language-detection"
,
"encoder_only"
,
"encoder_only"
,
False
,
False
,
"Pooling models with bidirectional attn do
es
not support chunked prefill."
,
"Pooling models with bidirectional attn do not support chunked prefill."
,
# noqa: E501
),
),
(
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct"
,
"Alibaba-NLP/gte-Qwen2-1.5B-instruct"
,
"encoder_only"
,
"encoder_only"
,
False
,
False
,
"Pooling models with bidirectional attn do
es
not support chunked prefill."
,
"Pooling models with bidirectional attn do not support chunked prefill."
,
# noqa: E501
),
),
(
(
"intfloat/e5-small"
,
"intfloat/e5-small"
,
"encoder_only"
,
"encoder_only"
,
False
,
False
,
"Pooling models with bidirectional attn do
es
not support chunked prefill."
,
"Pooling models with bidirectional attn do not support chunked prefill."
,
# noqa: E501
),
),
# multimodal models
# multimodal models
(
(
"openai/clip-vit-base-patch32"
,
"openai/clip-vit-base-patch32"
,
"decoder"
,
"decoder"
,
True
,
True
,
"Pooling models with causal attn and
last
pooling support chunked prefill."
,
"Pooling models with causal attn and
LAST/ALL
pooling support chunked prefill."
,
# noqa: E501
),
),
(
(
"google/siglip-base-patch16-224"
,
"google/siglip-base-patch16-224"
,
"encoder_only"
,
"encoder_only"
,
False
,
False
,
"Pooling models with bidirectional attn do
es
not support chunked prefill."
,
"Pooling models with bidirectional attn do not support chunked prefill."
,
# noqa: E501
),
),
# generate models
# generate models
(
(
"Qwen/Qwen3-0.6B"
,
"Qwen/Qwen3-0.6B"
,
"decoder"
,
"decoder"
,
True
,
True
,
"Generative models support chunked prefill."
,
"Generative models support chunked prefill."
,
# noqa: E501
),
),
(
(
"Qwen/Qwen3-Next-80B-A3B-Instruct"
,
"Qwen/Qwen3-Next-80B-A3B-Instruct"
,
"hybrid"
,
"hybrid"
,
True
,
True
,
"Generative models support chunked prefill."
,
"Generative models support chunked prefill."
,
# noqa: E501
),
),
(
(
"ibm-granite/granite-4.0-h-small"
,
"ibm-granite/granite-4.0-h-small"
,
"hybrid"
,
"hybrid"
,
True
,
True
,
"Generative models support chunked prefill."
,
"Generative models support chunked prefill."
,
# noqa: E501
),
),
(
(
"state-spaces/mamba-130m-hf"
,
"state-spaces/mamba-130m-hf"
,
"attention_free"
,
"attention_free"
,
True
,
True
,
"Generative models support chunked prefill."
,
"Generative models support chunked prefill."
,
# noqa: E501
),
),
# encoder_decoder models
# encoder_decoder models
(
(
"openai/whisper-small"
,
"openai/whisper-small"
,
"encoder_decoder"
,
"encoder_decoder"
,
False
,
False
,
"Encoder decoder models do
es
not support chunked prefill."
,
"Encoder decoder models do not support chunked prefill."
,
# noqa: E501
),
),
],
],
)
)
...
@@ -675,100 +687,100 @@ def test_is_chunked_prefill_supported(
...
@@ -675,100 +687,100 @@ def test_is_chunked_prefill_supported(
"jason9693/Qwen2.5-1.5B-apeach"
,
"jason9693/Qwen2.5-1.5B-apeach"
,
"decoder"
,
"decoder"
,
True
,
True
,
"Pooling models with causal attn and
last
pooling support prefix caching."
,
"Pooling models with causal attn and
LAST/ALL
pooling support prefix caching."
,
# noqa: E501
),
),
(
(
"Qwen/Qwen3-Embedding-0.6B"
,
"Qwen/Qwen3-Embedding-0.6B"
,
"decoder"
,
"decoder"
,
True
,
True
,
"Pooling models with causal attn and
last
pooling support prefix caching."
,
"Pooling models with causal attn and
LAST/ALL
pooling support prefix caching."
,
# noqa: E501
),
),
(
(
"Qwen/Qwen2.5-Math-PRM-7B"
,
"Qwen/Qwen2.5-Math-PRM-7B"
,
"decoder"
,
"decoder"
,
False
,
False
,
"Pooling models with
step
pooling do
es
not support prefix caching."
,
"Pooling models with
causal attn and LAST/STEP
pooling do not support prefix caching."
,
# noqa: E501
),
),
(
(
"internlm/internlm2-1_8b-reward"
,
"internlm/internlm2-1_8b-reward"
,
"decoder"
,
"decoder"
,
True
,
True
,
"Pooling models with causal attn and
all
pooling support prefix caching."
,
"Pooling models with causal attn and
LAST/ALL
pooling support prefix caching."
,
# noqa: E501
),
),
(
(
"BAAI/bge-base-en"
,
"BAAI/bge-base-en"
,
"encoder_only"
,
"encoder_only"
,
False
,
False
,
"Pooling models with bidirectional attn do
es
not support prefix caching."
,
"Pooling models with bidirectional attn do not support prefix caching."
,
# noqa: E501
),
),
(
(
"boltuix/NeuroBERT-NER"
,
"boltuix/NeuroBERT-NER"
,
"encoder_only"
,
"encoder_only"
,
False
,
False
,
"Pooling models with bidirectional attn do
es
not support prefix caching."
,
"Pooling models with bidirectional attn do not support prefix caching."
,
# noqa: E501
),
),
(
(
"papluca/xlm-roberta-base-language-detection"
,
"papluca/xlm-roberta-base-language-detection"
,
"encoder_only"
,
"encoder_only"
,
False
,
False
,
"Pooling models with bidirectional attn do
es
not support prefix caching."
,
"Pooling models with bidirectional attn do not support prefix caching."
,
# noqa: E501
),
),
(
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct"
,
"Alibaba-NLP/gte-Qwen2-1.5B-instruct"
,
"encoder_only"
,
"encoder_only"
,
False
,
False
,
"Pooling models with bidirectional attn do
es
not support prefix caching."
,
"Pooling models with bidirectional attn do not support prefix caching."
,
# noqa: E501
),
),
(
(
"intfloat/e5-small"
,
"intfloat/e5-small"
,
"encoder_only"
,
"encoder_only"
,
False
,
False
,
"Pooling models with bidirectional attn do
es
not support prefix caching."
,
"Pooling models with bidirectional attn do not support prefix caching."
,
# noqa: E501
),
),
# multimodal models
# multimodal models
(
(
"openai/clip-vit-base-patch32"
,
"openai/clip-vit-base-patch32"
,
"decoder"
,
"decoder"
,
True
,
True
,
"Pooling models with causal attn and
last
pooling support prefix caching."
,
"Pooling models with causal attn and
LAST/ALL
pooling support prefix caching."
,
# noqa: E501
),
),
(
(
"google/siglip-base-patch16-224"
,
"google/siglip-base-patch16-224"
,
"encoder_only"
,
"encoder_only"
,
False
,
False
,
"Pooling models with bidirectional attn do
es
not support prefix caching."
,
"Pooling models with bidirectional attn do not support prefix caching."
,
# noqa: E501
),
),
# generate models
# generate models
(
(
"Qwen/Qwen3-0.6B"
,
"Qwen/Qwen3-0.6B"
,
"decoder"
,
"decoder"
,
True
,
True
,
"Generative models support prefix caching."
,
"Generative models support prefix caching."
,
# noqa: E501
),
),
(
(
"Qwen/Qwen3-Next-80B-A3B-Instruct"
,
"Qwen/Qwen3-Next-80B-A3B-Instruct"
,
"hybrid"
,
"hybrid"
,
False
,
False
,
"Hybrid models do
es
not support prefix caching since the feature is still experimental."
,
# noqa: E501
"Hybrid models do not support prefix caching since the feature is still experimental."
,
# noqa: E501
),
),
(
(
"ibm-granite/granite-4.0-h-small"
,
"ibm-granite/granite-4.0-h-small"
,
"hybrid"
,
"hybrid"
,
False
,
False
,
"Hybrid models do
es
not support prefix caching since the feature is still experimental."
,
# noqa: E501
"Hybrid models do not support prefix caching since the feature is still experimental."
,
# noqa: E501
),
),
(
(
"state-spaces/mamba-130m-hf"
,
"state-spaces/mamba-130m-hf"
,
"attention_free"
,
"attention_free"
,
False
,
False
,
"Attention free models do
es
not support prefix caching since the feature is still experimental."
,
# noqa: E501
"Attention free models do not support prefix caching since the feature is still experimental."
,
# noqa: E501
),
),
# encoder_decoder models
# encoder_decoder models
(
(
"openai/whisper-small"
,
"openai/whisper-small"
,
"encoder_decoder"
,
"encoder_decoder"
,
False
,
False
,
"Encoder decoder models do
es
not support prefix caching."
,
"Encoder decoder models do not support prefix caching."
,
# noqa: E501
),
),
],
],
)
)
...
@@ -927,7 +939,7 @@ def test_vllm_config_callable_defaults():
...
@@ -927,7 +939,7 @@ def test_vllm_config_callable_defaults():
model_config
=
quantized_model
,
optimization_level
=
OptimizationLevel
.
O2
model_config
=
quantized_model
,
optimization_level
=
OptimizationLevel
.
O2
)
)
enable_if_quantized
=
lambda
cfg
:
(
enable_if_quantized
=
lambda
cfg
:
(
cfg
.
model_config
is
not
None
and
cfg
.
model_config
.
is_quantized
()
cfg
.
model_config
is
not
None
and
cfg
.
model_config
.
is_quantized
)
)
assert
enable_if_quantized
(
config_quantized
)
is
True
assert
enable_if_quantized
(
config_quantized
)
is
True
assert
enable_if_quantized
(
config_no_model
)
is
False
assert
enable_if_quantized
(
config_no_model
)
is
False
...
@@ -938,7 +950,7 @@ def test_vllm_config_callable_defaults():
...
@@ -938,7 +950,7 @@ def test_vllm_config_callable_defaults():
model_config
=
moe_model
,
optimization_level
=
OptimizationLevel
.
O2
model_config
=
moe_model
,
optimization_level
=
OptimizationLevel
.
O2
)
)
enable_if_sequential
=
lambda
cfg
:
(
enable_if_sequential
=
lambda
cfg
:
(
cfg
.
model_config
is
not
None
and
not
cfg
.
model_config
.
is_mo
del_moe
()
cfg
.
model_config
is
not
None
and
not
cfg
.
model_config
.
is_mo
e
)
)
assert
enable_if_sequential
(
config_moe
)
is
False
assert
enable_if_sequential
(
config_moe
)
is
False
assert
enable_if_sequential
(
config_quantized
)
is
True
assert
enable_if_sequential
(
config_quantized
)
is
True
...
@@ -1052,3 +1064,46 @@ def test_scheduler_config_init():
...
@@ -1052,3 +1064,46 @@ def test_scheduler_config_init():
with
pytest
.
raises
(
AttributeError
):
with
pytest
.
raises
(
AttributeError
):
# InitVar does not become an attribute
# InitVar does not become an attribute
print
(
SchedulerConfig
.
default_factory
().
max_model_len
)
print
(
SchedulerConfig
.
default_factory
().
max_model_len
)
@
pytest
.
mark
.
parametrize
(
(
"model_id"
,
"data_parallel_size"
,
"external_lb"
,
"expected_needs_coordinator"
,
),
[
# Non-MoE model with DP=1 should not need coordinator
(
"facebook/opt-125m"
,
1
,
False
,
False
),
# Non-MoE model with DP>1 internal LB should need coordinator
(
"facebook/opt-125m"
,
2
,
False
,
True
),
# Non-MoE model with DP>1 external LB should not need coordinator
(
"facebook/opt-125m"
,
2
,
True
,
False
),
# MoE model with DP=1 should not need coordinator
(
"mistralai/Mixtral-8x7B-Instruct-v0.1"
,
1
,
False
,
False
),
# MoE model with DP>1 internal LB should need both coordinator
# and wave coordination
(
"mistralai/Mixtral-8x7B-Instruct-v0.1"
,
2
,
False
,
True
),
# MoE model with DP>1 external LB needs coordinator for wave coordination
# (wave coordination runs in coordinator process)
(
"mistralai/Mixtral-8x7B-Instruct-v0.1"
,
2
,
True
,
True
),
],
)
def
test_needs_dp_coordination
(
model_id
,
data_parallel_size
,
external_lb
,
expected_needs_coordinator
,
):
"""Test that DP coordinator and wave coordination are configured correctly."""
from
vllm.config
import
ParallelConfig
model_config
=
ModelConfig
(
model_id
)
parallel_config
=
ParallelConfig
(
data_parallel_size
=
data_parallel_size
,
data_parallel_external_lb
=
external_lb
,
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
parallel_config
=
parallel_config
)
assert
vllm_config
.
needs_dp_coordinator
==
expected_needs_coordinator
tests/test_pooling_params.py
View file @
7e63ef82
...
@@ -18,7 +18,7 @@ EMBEDDING_MODELS = [
...
@@ -18,7 +18,7 @@ EMBEDDING_MODELS = [
]
]
classify_parameters
=
[
"use_activation"
]
classify_parameters
=
[
"use_activation"
]
embed_parameters
=
[
"dimensions"
,
"
normalize
"
]
embed_parameters
=
[
"dimensions"
,
"
use_activation
"
]
step_pooling_parameters
=
[
"step_tag_id"
,
"returned_token_ids"
]
step_pooling_parameters
=
[
"step_tag_id"
,
"returned_token_ids"
]
...
@@ -40,19 +40,19 @@ def test_task():
...
@@ -40,19 +40,19 @@ def test_task():
def
test_embed
():
def
test_embed
():
task
=
"embed"
task
=
"embed"
model_config
=
MockModelConfig
(
pooler_config
=
PoolerConfig
(
pooling_type
=
"CLS"
))
model_config
=
MockModelConfig
(
pooler_config
=
PoolerConfig
(
seq_
pooling_type
=
"CLS"
))
pooling_params
=
PoolingParams
(
normalize
=
None
)
pooling_params
=
PoolingParams
(
use_activation
=
None
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
=
PoolingParams
(
normalize
=
True
)
pooling_params
=
PoolingParams
(
use_activation
=
True
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
=
PoolingParams
(
normalize
=
False
)
pooling_params
=
PoolingParams
(
use_activation
=
False
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
invalid_parameters
=
classify_parameters
+
step_pooling_parameters
invalid_parameters
=
classify_parameters
+
step_pooling_parameters
for
p
in
invalid_parameters
:
for
p
in
set
(
invalid_parameters
)
-
set
(
embed_parameters
)
:
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
pooling_params
=
PoolingParams
(
**
{
p
:
True
})
pooling_params
=
PoolingParams
(
**
{
p
:
True
})
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
...
@@ -86,7 +86,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
...
@@ -86,7 +86,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
@
pytest
.
mark
.
parametrize
(
"task"
,
[
"score"
,
"classify"
])
@
pytest
.
mark
.
parametrize
(
"task"
,
[
"score"
,
"classify"
])
def
test_classify
(
task
):
def
test_classify
(
task
):
model_config
=
MockModelConfig
(
pooler_config
=
PoolerConfig
(
pooling_type
=
"CLS"
))
model_config
=
MockModelConfig
(
pooler_config
=
PoolerConfig
(
seq_
pooling_type
=
"CLS"
))
pooling_params
=
PoolingParams
(
use_activation
=
None
)
pooling_params
=
PoolingParams
(
use_activation
=
None
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
...
@@ -98,7 +98,7 @@ def test_classify(task):
...
@@ -98,7 +98,7 @@ def test_classify(task):
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
invalid_parameters
=
embed_parameters
+
step_pooling_parameters
invalid_parameters
=
embed_parameters
+
step_pooling_parameters
for
p
in
invalid_parameters
:
for
p
in
set
(
invalid_parameters
)
-
set
(
classify_parameters
)
:
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
pooling_params
=
PoolingParams
(
**
{
p
:
True
})
pooling_params
=
PoolingParams
(
**
{
p
:
True
})
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
...
@@ -108,23 +108,23 @@ def test_classify(task):
...
@@ -108,23 +108,23 @@ def test_classify(task):
def
test_token_embed
(
pooling_type
:
str
):
def
test_token_embed
(
pooling_type
:
str
):
task
=
"token_embed"
task
=
"token_embed"
model_config
=
MockModelConfig
(
model_config
=
MockModelConfig
(
pooler_config
=
PoolerConfig
(
pooling_type
=
pooling_type
)
pooler_config
=
PoolerConfig
(
tok_
pooling_type
=
pooling_type
)
)
)
pooling_params
=
PoolingParams
(
normalize
=
None
)
pooling_params
=
PoolingParams
(
use_activation
=
None
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
=
PoolingParams
(
normalize
=
True
)
pooling_params
=
PoolingParams
(
use_activation
=
True
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
=
PoolingParams
(
normalize
=
False
)
pooling_params
=
PoolingParams
(
use_activation
=
False
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
invalid_parameters
=
classify_parameters
invalid_parameters
=
classify_parameters
if
pooling_type
!=
"STEP"
:
if
pooling_type
!=
"STEP"
:
invalid_parameters
=
classify_parameters
+
step_pooling_parameters
invalid_parameters
=
classify_parameters
+
step_pooling_parameters
for
p
in
invalid_parameters
:
for
p
in
set
(
invalid_parameters
)
-
set
(
embed_parameters
)
:
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
pooling_params
=
PoolingParams
(
**
{
p
:
True
})
pooling_params
=
PoolingParams
(
**
{
p
:
True
})
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
...
@@ -134,7 +134,7 @@ def test_token_embed(pooling_type: str):
...
@@ -134,7 +134,7 @@ def test_token_embed(pooling_type: str):
def
test_token_classify
(
pooling_type
:
str
):
def
test_token_classify
(
pooling_type
:
str
):
task
=
"token_classify"
task
=
"token_classify"
model_config
=
MockModelConfig
(
model_config
=
MockModelConfig
(
pooler_config
=
PoolerConfig
(
pooling_type
=
pooling_type
)
pooler_config
=
PoolerConfig
(
tok_
pooling_type
=
pooling_type
)
)
)
pooling_params
=
PoolingParams
(
use_activation
=
None
)
pooling_params
=
PoolingParams
(
use_activation
=
None
)
...
@@ -150,7 +150,7 @@ def test_token_classify(pooling_type: str):
...
@@ -150,7 +150,7 @@ def test_token_classify(pooling_type: str):
if
pooling_type
!=
"STEP"
:
if
pooling_type
!=
"STEP"
:
invalid_parameters
=
embed_parameters
+
step_pooling_parameters
invalid_parameters
=
embed_parameters
+
step_pooling_parameters
for
p
in
invalid_parameters
:
for
p
in
set
(
invalid_parameters
)
-
set
(
classify_parameters
)
:
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
pooling_params
=
PoolingParams
(
**
{
p
:
True
})
pooling_params
=
PoolingParams
(
**
{
p
:
True
})
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
pooling_params
.
verify
(
task
=
task
,
model_config
=
model_config
)
tests/test_routing_simulator.py
View file @
7e63ef82
...
@@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device):
...
@@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device):
envs
.
environment_variables
[
env_name
]
=
lambda
s
=
strategy
:
s
envs
.
environment_variables
[
env_name
]
=
lambda
s
=
strategy
:
s
# Test the select_experts method
# Test the select_experts method
topk_weights
,
topk_ids
,
_
=
fused_moe
.
select_experts
(
topk_weights
,
topk_ids
=
fused_moe
.
router
.
select_experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
)
)
...
...
tests/tokenizers_/test_basic.py
View file @
7e63ef82
...
@@ -10,6 +10,7 @@ from transformers import (
...
@@ -10,6 +10,7 @@ from transformers import (
)
)
from
vllm.tokenizers
import
TokenizerLike
,
get_tokenizer
from
vllm.tokenizers
import
TokenizerLike
,
get_tokenizer
from
vllm.tokenizers.grok2
import
Grok2Tokenizer
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.tokenizers.mistral
import
MistralTokenizer
...
@@ -37,6 +38,10 @@ def test_tokenizer_like_protocol():
...
@@ -37,6 +38,10 @@ def test_tokenizer_like_protocol():
assert
isinstance
(
tokenizer
,
MistralTokenizer
)
assert
isinstance
(
tokenizer
,
MistralTokenizer
)
_assert_tokenizer_like
(
tokenizer
)
_assert_tokenizer_like
(
tokenizer
)
tokenizer
=
get_tokenizer
(
"xai-org/grok-2"
,
tokenizer_mode
=
"grok2"
)
assert
isinstance
(
tokenizer
,
Grok2Tokenizer
)
_assert_tokenizer_like
(
tokenizer
)
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
[
"facebook/opt-125m"
,
"gpt2"
])
@
pytest
.
mark
.
parametrize
(
"tokenizer_name"
,
[
"facebook/opt-125m"
,
"gpt2"
])
def
test_tokenizer_revision
(
tokenizer_name
:
str
):
def
test_tokenizer_revision
(
tokenizer_name
:
str
):
...
...
tests/tokenizers_/test_detokenize.py
View file @
7e63ef82
...
@@ -40,7 +40,8 @@ TOKENIZERS = [
...
@@ -40,7 +40,8 @@ TOKENIZERS = [
os
.
path
.
join
(
models_path_prefix
,
"EleutherAI/gpt-j-6b"
),
os
.
path
.
join
(
models_path_prefix
,
"EleutherAI/gpt-j-6b"
),
os
.
path
.
join
(
models_path_prefix
,
"EleutherAI/pythia-70m"
),
os
.
path
.
join
(
models_path_prefix
,
"EleutherAI/pythia-70m"
),
os
.
path
.
join
(
models_path_prefix
,
"bigscience/bloom-560m"
),
os
.
path
.
join
(
models_path_prefix
,
"bigscience/bloom-560m"
),
os
.
path
.
join
(
models_path_prefix
,
"mosaicml/mpt-7b"
),
# FIXME: mosaicml/mpt-7b has been deleted
# "mosaicml/mpt-7b",
os
.
path
.
join
(
models_path_prefix
,
"tiiuae/falcon-7b"
),
os
.
path
.
join
(
models_path_prefix
,
"tiiuae/falcon-7b"
),
os
.
path
.
join
(
models_path_prefix
,
"meta-llama/Llama-3.2-1B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"meta-llama/Llama-3.2-1B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"codellama/CodeLlama-7b-hf"
),
os
.
path
.
join
(
models_path_prefix
,
"codellama/CodeLlama-7b-hf"
),
...
...
tests/tool_parsers/test_functiongemma_tool_parser.py
0 → 100644
View file @
7e63ef82
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
unittest.mock
import
MagicMock
import
pytest
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.tool_parsers.functiongemma_tool_parser
import
FunctionGemmaToolParser
@
pytest
.
fixture
def
mock_tokenizer
():
tokenizer
=
MagicMock
()
tokenizer
.
encode
.
return_value
=
[
1
,
2
,
3
]
tokenizer
.
get_vocab
.
return_value
=
{}
return
tokenizer
@
pytest
.
fixture
def
parser
(
mock_tokenizer
):
return
FunctionGemmaToolParser
(
mock_tokenizer
)
@
pytest
.
fixture
def
mock_request
():
request
=
MagicMock
(
spec
=
ChatCompletionRequest
)
request
.
tools
=
[]
request
.
tool_choice
=
"auto"
return
request
class
TestExtractToolCalls
:
def
test_no_tool_calls
(
self
,
parser
,
mock_request
):
model_output
=
"Hello, how can I help you today?"
result
=
parser
.
extract_tool_calls
(
model_output
,
mock_request
)
assert
result
.
tools_called
is
False
assert
result
.
tool_calls
==
[]
assert
result
.
content
==
model_output
def
test_single_tool_call
(
self
,
parser
,
mock_request
):
model_output
=
(
"<start_function_call>call:get_weather{location:<escape>London<escape>}"
"<end_function_call>"
)
result
=
parser
.
extract_tool_calls
(
model_output
,
mock_request
)
assert
result
.
tools_called
is
True
assert
len
(
result
.
tool_calls
)
==
1
assert
result
.
tool_calls
[
0
].
function
.
name
==
"get_weather"
assert
'"location": "London"'
in
result
.
tool_calls
[
0
].
function
.
arguments
def
test_multiple_arguments
(
self
,
parser
,
mock_request
):
model_output
=
(
"<start_function_call>call:get_weather{"
"location:<escape>San Francisco<escape>,"
"unit:<escape>celsius<escape>}"
"<end_function_call>"
)
result
=
parser
.
extract_tool_calls
(
model_output
,
mock_request
)
assert
result
.
tools_called
is
True
assert
len
(
result
.
tool_calls
)
==
1
assert
result
.
tool_calls
[
0
].
function
.
name
==
"get_weather"
args
=
result
.
tool_calls
[
0
].
function
.
arguments
assert
"San Francisco"
in
args
assert
"celsius"
in
args
def
test_text_before_tool_call
(
self
,
parser
,
mock_request
):
model_output
=
(
"Let me check the weather for you. "
"<start_function_call>call:get_weather{location:<escape>Paris<escape>}"
"<end_function_call>"
)
result
=
parser
.
extract_tool_calls
(
model_output
,
mock_request
)
assert
result
.
tools_called
is
True
assert
result
.
content
==
"Let me check the weather for you."
def
test_multiple_tool_calls
(
self
,
parser
,
mock_request
):
model_output
=
(
"<start_function_call>call:get_weather{location:<escape>London<escape>}"
"<end_function_call>"
"<start_function_call>call:get_time{timezone:<escape>UTC<escape>}"
"<end_function_call>"
)
result
=
parser
.
extract_tool_calls
(
model_output
,
mock_request
)
assert
result
.
tools_called
is
True
assert
len
(
result
.
tool_calls
)
==
2
assert
result
.
tool_calls
[
0
].
function
.
name
==
"get_weather"
assert
result
.
tool_calls
[
1
].
function
.
name
==
"get_time"
class
TestParseArguments
:
def
test_empty_arguments
(
self
,
parser
):
result
=
parser
.
_parse_arguments
(
""
)
assert
result
==
{}
def
test_single_string_argument
(
self
,
parser
):
result
=
parser
.
_parse_arguments
(
"city:<escape>Tokyo<escape>"
)
assert
result
==
{
"city"
:
"Tokyo"
}
def
test_multiple_arguments
(
self
,
parser
):
args_str
=
"city:<escape>Tokyo<escape>,country:<escape>Japan<escape>"
result
=
parser
.
_parse_arguments
(
args_str
)
assert
result
==
{
"city"
:
"Tokyo"
,
"country"
:
"Japan"
}
def
test_numeric_argument
(
self
,
parser
):
result
=
parser
.
_parse_arguments
(
"count:<escape>42<escape>"
)
assert
result
==
{
"count"
:
42
}
def
test_boolean_argument
(
self
,
parser
):
result
=
parser
.
_parse_arguments
(
"enabled:<escape>true<escape>"
)
assert
result
==
{
"enabled"
:
True
}
def
test_argument_with_spaces
(
self
,
parser
):
result
=
parser
.
_parse_arguments
(
"message:<escape>Hello World<escape>"
)
assert
result
==
{
"message"
:
"Hello World"
}
class
TestAdjustRequest
:
def
test_skip_special_tokens_disabled
(
self
,
parser
,
mock_request
):
mock_request
.
tools
=
[{
"type"
:
"function"
,
"function"
:
{
"name"
:
"test"
}}]
mock_request
.
tool_choice
=
"auto"
mock_request
.
skip_special_tokens
=
True
result
=
parser
.
adjust_request
(
mock_request
)
assert
result
.
skip_special_tokens
is
False
def
test_skip_special_tokens_when_tool_choice_none
(
self
,
parser
,
mock_request
):
mock_request
.
tools
=
[{
"type"
:
"function"
,
"function"
:
{
"name"
:
"test"
}}]
mock_request
.
tool_choice
=
"none"
mock_request
.
skip_special_tokens
=
True
result
=
parser
.
adjust_request
(
mock_request
)
assert
result
.
skip_special_tokens
is
True
class
TestBufferDeltaText
:
def
test_regular_text_not_buffered
(
self
,
parser
):
result
=
parser
.
_buffer_delta_text
(
"hello"
)
assert
result
==
"hello"
assert
parser
.
buffered_delta_text
==
""
def
test_complete_tag_flushed
(
self
,
parser
):
parser
.
buffered_delta_text
=
"<start_function_"
result
=
parser
.
_buffer_delta_text
(
"call>"
)
assert
"<start_function_call>"
in
result
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
,
"-v"
])
tests/tool_parsers/test_kimi_k2_tool_parser.py
View file @
7e63ef82
...
@@ -44,6 +44,33 @@ def assert_tool_calls(
...
@@ -44,6 +44,33 @@ def assert_tool_calls(
)
)
def
run_streaming_sequence
(
parser
,
deltas
):
"""Helper to simulate a streaming sequence and return results."""
previous_text
=
""
previous_token_ids
:
list
[
int
]
=
[]
results
=
[]
for
delta_text
,
delta_token_ids
in
deltas
:
current_text
=
previous_text
+
delta_text
current_token_ids
=
previous_token_ids
+
delta_token_ids
result
=
parser
.
extract_tool_calls_streaming
(
previous_text
=
previous_text
,
current_text
=
current_text
,
delta_text
=
delta_text
,
previous_token_ids
=
previous_token_ids
,
current_token_ids
=
current_token_ids
,
delta_token_ids
=
delta_token_ids
,
request
=
None
,
)
results
.
append
(
result
)
previous_text
=
current_text
previous_token_ids
=
current_token_ids
return
results
def
test_extract_tool_calls_no_tools
(
kimi_k2_tool_parser
):
def
test_extract_tool_calls_no_tools
(
kimi_k2_tool_parser
):
model_output
=
"This is a test"
model_output
=
"This is a test"
extracted_tool_calls
=
kimi_k2_tool_parser
.
extract_tool_calls
(
extracted_tool_calls
=
kimi_k2_tool_parser
.
extract_tool_calls
(
...
@@ -346,61 +373,32 @@ def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser):
...
@@ -346,61 +373,32 @@ def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser):
tool_call_begin_token_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_call_begin|>"
)
tool_call_begin_token_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_call_begin|>"
)
# Simulate streaming sequence:
# Simulate streaming sequence:
deltas
=
[
(
"I'll help you with that. "
,
[
1
,
2
,
3
]),
(
"<|tool_calls_section_begin|>"
,
[
section_begin_token_id
]),
(
" spurious text "
,
[
4
,
5
]),
(
"<|tool_call_begin|>"
,
[
tool_call_begin_token_id
]),
]
results
=
run_streaming_sequence
(
kimi_k2_tool_parser
,
deltas
)
# Delta 1: "I'll help you with that. "
# Delta 1: "I'll help you with that. "
result1
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
assert
results
[
0
]
is
not
None
previous_text
=
""
,
assert
results
[
0
].
content
==
"I'll help you with that. "
current_text
=
"I'll help you with that. "
,
delta_text
=
"I'll help you with that. "
,
previous_token_ids
=
[],
current_token_ids
=
[
1
,
2
,
3
],
# Regular tokens
delta_token_ids
=
[
1
,
2
,
3
],
request
=
None
,
)
assert
result1
is
not
None
assert
result1
.
content
==
"I'll help you with that. "
# Delta 2: "<|tool_calls_section_begin|>"
# Delta 2: "<|tool_calls_section_begin|>"
prev_ids
=
[
1
,
2
,
3
]
curr_ids
=
prev_ids
+
[
section_begin_token_id
]
result2
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
"I'll help you with that. "
,
current_text
=
"I'll help you with that. <|tool_calls_section_begin|>"
,
delta_text
=
"<|tool_calls_section_begin|>"
,
previous_token_ids
=
prev_ids
,
current_token_ids
=
curr_ids
,
delta_token_ids
=
[
section_begin_token_id
],
request
=
None
,
)
# Section marker should be stripped and suppressed
# Section marker should be stripped and suppressed
assert
result2
is
None
or
(
result2
.
content
is
None
or
result2
.
content
==
""
)
assert
results
[
1
]
is
None
or
(
results
[
1
].
content
is
None
or
results
[
1
].
content
==
""
)
# Delta 3: " spurious text or tokens " (THE LEAK SCENARIO)
# Delta 3: " spurious text or tokens " (THE LEAK SCENARIO)
prev_ids
=
curr_ids
curr_ids
=
curr_ids
+
[
4
,
5
]
result3
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
"I'll help you with that. <|tool_calls_section_begin|>"
,
current_text
=
"I'll help you with that. <|tool_calls_section_begin|> spurious text "
,
delta_text
=
" spurious text "
,
previous_token_ids
=
prev_ids
,
current_token_ids
=
curr_ids
,
delta_token_ids
=
[
4
,
5
],
request
=
None
,
)
# CRITICAL: This text should be suppressed, NOT returned as reasoning_delta
# CRITICAL: This text should be suppressed, NOT returned as reasoning_delta
assert
result3
is
None
or
(
result3
.
content
is
None
or
result3
.
content
==
""
)
assert
results
[
2
]
is
None
or
(
results
[
2
].
content
is
None
or
results
[
2
].
content
==
""
)
# Delta 4: "<|tool_call_begin|>..."
# Delta 4: "<|tool_call_begin|>..."
prev_ids
=
curr_ids
curr_ids
=
curr_ids
+
[
tool_call_begin_token_id
]
_result4
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
"I'll help you with that. <|tool_calls_section_begin|> spurious text "
,
current_text
=
"I'll help you with that. <|tool_calls_section_begin|> spurious text <|tool_call_begin|>"
,
delta_text
=
"<|tool_call_begin|>"
,
previous_token_ids
=
prev_ids
,
current_token_ids
=
curr_ids
,
delta_token_ids
=
[
tool_call_begin_token_id
],
request
=
None
,
)
# Now we're in tool call mode, result depends on internal state
# Now we're in tool call mode, result depends on internal state
# The key is that the spurious text from Delta 3 was not leaked
# The key is that the spurious text from Delta 3 was not leaked
...
@@ -416,31 +414,15 @@ def test_split_markers_across_deltas(kimi_k2_tool_parser):
...
@@ -416,31 +414,15 @@ def test_split_markers_across_deltas(kimi_k2_tool_parser):
"<|tool_calls_section_begin|>"
"<|tool_calls_section_begin|>"
)
)
# Delta 1: "...reasoning<|tool_calls_sec"
# Delta 1: partial token, Delta 2: complete marker
_result1
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
deltas
=
[
previous_text
=
"Some reasoning"
,
(
"<|tool_calls_sec"
,
[
3
]),
current_text
=
"Some reasoning<|tool_calls_sec"
,
(
"tion_begin|> "
,
[
section_begin_token_id
,
4
]),
delta_text
=
"<|tool_calls_sec"
,
]
previous_token_ids
=
[
1
,
2
],
current_token_ids
=
[
1
,
2
,
3
],
# Partial token
_results
=
run_streaming_sequence
(
kimi_k2_tool_parser
,
deltas
)
delta_token_ids
=
[
3
],
request
=
None
,
)
# Partial token not recognized yet, might be buffered
# Should return as content or None (depends on implementation)
# Delta 2: "tion_begin|> " (completes the marker)
_result2
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
"Some reasoning<|tool_calls_sec"
,
current_text
=
"Some reasoning<|tool_calls_section_begin|> "
,
delta_text
=
"tion_begin|> "
,
previous_token_ids
=
[
1
,
2
,
3
],
current_token_ids
=
[
1
,
2
,
section_begin_token_id
,
4
],
delta_token_ids
=
[
section_begin_token_id
,
4
],
request
=
None
,
)
# Now the complete marker should be detected via buffer
# Now the complete marker should be detected via buffer
# The parser should enter tool section mode
assert
kimi_k2_tool_parser
.
in_tool_section
is
True
assert
kimi_k2_tool_parser
.
in_tool_section
is
True
...
@@ -475,42 +457,17 @@ def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser):
...
@@ -475,42 +457,17 @@ def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser):
section_begin_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_calls_section_begin|>"
)
section_begin_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_calls_section_begin|>"
)
section_end_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_calls_section_end|>"
)
section_end_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_calls_section_end|>"
)
# Enter tool section
deltas
=
[
_result1
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
(
"<|tool_calls_section_begin|>"
,
[
section_begin_id
]),
previous_text
=
""
,
(
"<|tool_calls_section_end|>"
,
[
section_end_id
]),
current_text
=
"<|tool_calls_section_begin|>"
,
(
" More reasoning"
,
[
10
,
11
]),
delta_text
=
"<|tool_calls_section_begin|>"
,
]
previous_token_ids
=
[],
current_token_ids
=
[
section_begin_id
],
delta_token_ids
=
[
section_begin_id
],
request
=
None
,
)
assert
kimi_k2_tool_parser
.
in_tool_section
is
True
# Exit tool section
results
=
run_streaming_sequence
(
kimi_k2_tool_parser
,
deltas
)
_result2
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
"<|tool_calls_section_begin|>"
,
current_text
=
"<|tool_calls_section_begin|><|tool_calls_section_end|>"
,
delta_text
=
"<|tool_calls_section_end|>"
,
previous_token_ids
=
[
section_begin_id
],
current_token_ids
=
[
section_begin_id
,
section_end_id
],
delta_token_ids
=
[
section_end_id
],
request
=
None
,
)
assert
kimi_k2_tool_parser
.
in_tool_section
is
False
# Subsequent reasoning text should be returned normally
assert
kimi_k2_tool_parser
.
in_tool_section
is
False
result3
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
assert
results
[
2
]
is
not
None
previous_text
=
"<|tool_calls_section_begin|><|tool_calls_section_end|>"
,
assert
results
[
2
].
content
==
" More reasoning"
current_text
=
"<|tool_calls_section_begin|><|tool_calls_section_end|> More reasoning"
,
delta_text
=
" More reasoning"
,
previous_token_ids
=
[
section_begin_id
,
section_end_id
],
current_token_ids
=
[
section_begin_id
,
section_end_id
,
10
,
11
],
delta_token_ids
=
[
10
,
11
],
request
=
None
,
)
assert
result3
is
not
None
assert
result3
.
content
==
" More reasoning"
def
test_empty_tool_section
(
kimi_k2_tool_parser
):
def
test_empty_tool_section
(
kimi_k2_tool_parser
):
...
@@ -819,106 +776,150 @@ def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser):
...
@@ -819,106 +776,150 @@ def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser):
tool_end_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_call_end|>"
)
tool_end_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_call_end|>"
)
# Simulate a streaming sequence for a SHORT tool call (all in one chunk):
# Simulate a streaming sequence for a SHORT tool call (all in one chunk):
# 1. Reasoning text
result1
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
"Let me help. "
,
delta_text
=
"Let me help. "
,
previous_token_ids
=
[],
current_token_ids
=
[
1
,
2
],
delta_token_ids
=
[
1
,
2
],
request
=
None
,
)
assert
result1
is
not
None
assert
result1
.
content
==
"Let me help. "
# 2. Section begin
_result2
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
"Let me help. "
,
current_text
=
"Let me help. <|tool_calls_section_begin|>"
,
delta_text
=
"<|tool_calls_section_begin|>"
,
previous_token_ids
=
[
1
,
2
],
current_token_ids
=
[
1
,
2
,
section_begin_id
],
delta_token_ids
=
[
section_begin_id
],
request
=
None
,
)
assert
kimi_k2_tool_parser
.
in_tool_section
is
True
# 3. Tool call begin + full content + tool_end + section_end ALL IN ONE CHUNK
# This is the critical scenario for short tool calls
combined
=
(
combined
=
(
'<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} '
'<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} '
"<|tool_call_end|><|tool_calls_section_end|>"
"<|tool_call_end|><|tool_calls_section_end|>"
)
)
# Build up the previous text gradually to simulate realistic streaming
deltas
=
[
prev_text
=
"Let me help. <|tool_calls_section_begin|>"
(
"Let me help. "
,
[
1
,
2
]),
curr_text
=
prev_text
+
combined
(
"<|tool_calls_section_begin|>"
,
[
section_begin_id
]),
(
combined
,
[
tool_begin_id
,
10
,
11
,
12
,
tool_end_id
,
section_end_id
]),
(
" Done"
,
[
20
]),
]
result3
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
results
=
run_streaming_sequence
(
kimi_k2_tool_parser
,
deltas
)
previous_text
=
prev_text
,
current_text
=
curr_text
,
delta_text
=
combined
,
previous_token_ids
=
[
1
,
2
,
section_begin_id
],
current_token_ids
=
[
1
,
2
,
section_begin_id
,
tool_begin_id
,
10
,
11
,
12
,
tool_end_id
,
section_end_id
,
],
delta_token_ids
=
[
tool_begin_id
,
10
,
11
,
12
,
tool_end_id
,
section_end_id
],
request
=
None
,
)
# CRITICAL: Parser should have exited section AFTER processing tool
# CRITICAL: Parser should have exited section AFTER processing tool
assert
kimi_k2_tool_parser
.
in_tool_section
is
False
assert
kimi_k2_tool_parser
.
in_tool_section
is
False
# Tool call should have been emitted (not dropped)
# Tool call should have been emitted (not dropped)
# The result might be the tool name or None depending on state, but
if
results
[
2
]
is
not
None
and
results
[
2
].
content
is
not
None
:
# importantly, it shouldn't be returning the literal tokens as content
if
result3
is
not
None
and
result3
.
content
is
not
None
:
# Verify no special tokens leaked into content
# Verify no special tokens leaked into content
assert
"<|tool_call_end|>"
not
in
result
3
.
content
assert
"<|tool_call_end|>"
not
in
result
s
[
2
]
.
content
assert
"<|tool_calls_section_end|>"
not
in
result
3
.
content
assert
"<|tool_calls_section_end|>"
not
in
result
s
[
2
]
.
content
# 4. Verify subsequent content streams normally
# Content after tool section should stream normally
result4
=
kimi_k2_tool_parser
.
extract_tool_calls_streaming
(
assert
results
[
3
]
is
not
None
previous_text
=
curr_text
,
assert
results
[
3
].
content
==
" Done"
current_text
=
curr_text
+
" Done"
,
delta_text
=
" Done"
,
previous_token_ids
=
[
def
test_streaming_tool_call_markers_not_leaked
(
kimi_k2_tool_parser
):
1
,
"""
2
,
CRITICAL TEST: Verify that tool call markers (<|tool_call_begin|>,
section_begin_id
,
<|tool_call_end|>, <|tool_call_argument_begin|>) are NOT leaked
tool_begin_id
,
into the content field during streaming.
10
,
11
,
This reproduces the AWS Bedrock bug where tool call markers appeared
12
,
in the 'text' field of responses.
tool_end_id
,
"""
section_end_id
,
kimi_k2_tool_parser
.
reset_streaming_state
()
],
current_token_ids
=
[
section_begin_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_calls_section_begin|>"
)
1
,
section_end_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_calls_section_end|>"
)
2
,
tool_begin_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_call_begin|>"
)
section_begin_id
,
tool_end_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_call_end|>"
)
tool_begin_id
,
10
,
# List of markers that should NEVER appear in content
11
,
forbidden_markers
=
[
12
,
"<|tool_call_begin|>"
,
tool_end_id
,
"<|tool_call_end|>"
,
section_end_id
,
"<|tool_call_argument_begin|>"
,
20
,
"<|tool_calls_section_begin|>"
,
],
"<|tool_calls_section_end|>"
,
delta_token_ids
=
[
20
],
]
request
=
None
,
all_content
=
[]
# Steps: reasoning, section begin, tool call, section end, more reasoning
tool_chunk
=
(
"<|tool_call_begin|> functions.get_weather:0 "
'<|tool_call_argument_begin|> {"city": "Tokyo"} <|tool_call_end|>'
)
)
deltas
=
[
(
"I'll check the weather. "
,
[
1
,
2
,
3
]),
(
"<|tool_calls_section_begin|>"
,
[
section_begin_id
]),
(
tool_chunk
,
[
tool_begin_id
,
10
,
11
,
tool_end_id
]),
(
"<|tool_calls_section_end|>"
,
[
section_end_id
]),
(
" Here's the result."
,
[
20
,
21
]),
]
results
=
run_streaming_sequence
(
kimi_k2_tool_parser
,
deltas
)
for
res
in
results
:
if
res
and
res
.
content
:
all_content
.
append
(
res
.
content
)
# CRITICAL ASSERTIONS: No forbidden markers in any content
full_content
=
""
.
join
(
all_content
)
for
marker
in
forbidden_markers
:
assert
marker
not
in
full_content
,
(
f
"MARKER LEAK DETECTED: '
{
marker
}
' found in content. "
f
"Full content:
{
repr
(
full_content
)
}
"
)
# Content after tool section should stream normally
# Also check that tool call content (function name, arguments) is not leaked
assert
result4
is
not
None
assert
"get_weather"
not
in
full_content
,
(
assert
result4
.
content
==
" Done"
f
"TOOL CALL CONTENT LEAKED: 'get_weather' found in content. "
f
"Full content:
{
repr
(
full_content
)
}
"
)
assert
"Tokyo"
not
in
full_content
,
(
f
"TOOL CALL CONTENT LEAKED: 'Tokyo' found in content. "
f
"Full content:
{
repr
(
full_content
)
}
"
)
# Verify that legitimate content was preserved
assert
"I'll check the weather."
in
full_content
or
len
(
all_content
)
>
0
def
test_streaming_multiple_tool_calls_not_leaked
(
kimi_k2_tool_parser
):
"""
Test that MULTIPLE tool calls in streaming mode do not leak into content.
This reproduces the AWS Bedrock scenario: "Compare weather in Tokyo and NYC".
"""
kimi_k2_tool_parser
.
reset_streaming_state
()
section_begin_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_calls_section_begin|>"
)
section_end_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_calls_section_end|>"
)
tool_begin_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_call_begin|>"
)
tool_end_id
=
kimi_k2_tool_parser
.
vocab
.
get
(
"<|tool_call_end|>"
)
all_content
=
[]
tool1
=
'<|tool_call_begin|> get_weather:0 <|tool_call_argument_begin|> {"city": "Tokyo"} <|tool_call_end|>'
tool2
=
' <|tool_call_begin|> get_weather:1 <|tool_call_argument_begin|> {"city": "New York"} <|tool_call_end|>'
deltas
=
[
(
"I'll compare the weather. "
,
[
1
,
2
,
3
]),
(
"<|tool_calls_section_begin|>"
,
[
section_begin_id
]),
(
tool1
,
[
tool_begin_id
,
10
,
tool_end_id
]),
(
tool2
,
[
tool_begin_id
,
20
,
tool_end_id
]),
(
"<|tool_calls_section_end|>"
,
[
section_end_id
]),
(
" Here's the comparison."
,
[
30
]),
]
results
=
run_streaming_sequence
(
kimi_k2_tool_parser
,
deltas
)
for
res
in
results
:
if
res
and
res
.
content
:
all_content
.
append
(
res
.
content
)
# Assertions
full_content
=
""
.
join
(
all_content
)
# Check no markers leaked
forbidden
=
[
"<|tool_call"
,
"<|tool_calls_section"
]
for
marker
in
forbidden
:
assert
marker
not
in
full_content
,
(
f
"MARKER LEAKED:
{
marker
}
in
{
repr
(
full_content
)
}
"
)
# Check no tool call content leaked (both tools)
assert
"get_weather"
not
in
full_content
,
f
"TOOL NAME LEAKED:
{
repr
(
full_content
)
}
"
assert
"Tokyo"
not
in
full_content
,
f
"TOOL ARG LEAKED (Tokyo):
{
repr
(
full_content
)
}
"
assert
"New York"
not
in
full_content
,
(
f
"TOOL ARG LEAKED (NYC):
{
repr
(
full_content
)
}
"
)
# Legitimate content preserved
assert
"compare"
in
full_content
.
lower
()
or
len
(
all_content
)
>
0
tests/tool_parsers/test_mistral_tool_parser.py
View file @
7e63ef82
...
@@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer(
...
@@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer(
"single_tool_add"
,
"single_tool_add"
,
"single_tool_weather"
,
"single_tool_weather"
,
"multiple_tool_calls"
,
"multiple_tool_calls"
,
"complex"
,
"wrong_json"
,
],
],
argnames
=
[
"model_output"
,
"expected_tool_calls"
,
"expected_content"
],
argnames
=
[
"model_output"
,
"expected_tool_calls"
,
"expected_content"
],
argvalues
=
[
argvalues
=
[
...
@@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer(
...
@@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer(
],
],
None
,
None
,
),
),
(
# Complex
"""hi{hi[TOOL_CALLS]bash{"command": "print(
\\
"hello world!
\\
")
\\
nre.compile(r
\'
{}
\'
)"""
,
# noqa: E501
[
ToolCall
(
function
=
FunctionCall
(
name
=
"bash"
,
arguments
=
json
.
dumps
(
{
"command"
:
"print(
\"
hello world!
\"
)
\n
re.compile(r'{}')"
}
)[:
-
2
],
)
)
],
"hi{hi"
,
),
(
# Wrong json
"""hi{hi[TOOL_CALLS]bash{"command": "print(
\\
"hello world!
\\
")
\\
nre.compile(r
\'
{}
\'
)"}"""
,
# noqa: E501
[
ToolCall
(
function
=
FunctionCall
(
name
=
"bash"
,
arguments
=
json
.
dumps
(
{
"command"
:
"print(
\"
hello world!
\"
)
\n
re.compile(r'{}')"
}
),
)
)
],
"hi{hi"
,
),
],
],
)
)
def
test_extract_tool_calls
(
def
test_extract_tool_calls
(
...
@@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming(
...
@@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming(
),
),
(
(
# Complex
# Complex
"""[TOOL_CALLS]bash{"command": "print(
\\
"hello world!
\\
")
\\
nre.compile(r
\'
{}
\'
)"}"""
,
# noqa: E501
"""
hi{hi
[TOOL_CALLS]bash{"command": "print(
\\
"hello world!
\\
")
\\
nre.compile(r
\'
{}
\'
)"}"""
,
# noqa: E501
[
[
ToolCall
(
ToolCall
(
function
=
FunctionCall
(
function
=
FunctionCall
(
...
@@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming(
...
@@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming(
)
)
)
)
],
],
""
,
"
hi{hi
"
,
),
),
],
],
)
)
...
...
tests/tool_use/test_chat_completions.py
View file @
7e63ef82
...
@@ -151,3 +151,45 @@ async def test_chat_completion_with_tools(
...
@@ -151,3 +151,45 @@ async def test_chat_completion_with_tools(
assert
chunk
.
choices
[
0
].
finish_reason
!=
"tool_calls"
assert
chunk
.
choices
[
0
].
finish_reason
!=
"tool_calls"
assert
len
(
chunks
)
assert
len
(
chunks
)
assert
""
.
join
(
chunks
)
==
output_text
assert
""
.
join
(
chunks
)
==
output_text
# Regression test for https://github.com/vllm-project/vllm/issues/32006
# Engine crash when combining response_format: json_object with
# tool_choice: required
@pytest.mark.asyncio
@pytest.mark.timeout(120)
async def test_response_format_with_tool_choice_required(
    client: openai.AsyncOpenAI, server_config: ServerConfig
):
    """Check that json_object response_format plus tool_choice="required" works.

    Before the fix this request failed validation with "You can only use one
    kind of structured outputs constraint but multiple are specified",
    because both json_object and json (derived from the tool schema) ended up
    set in StructuredOutputsParams.
    """
    # Use whichever model the server under test reports first.
    available_models = await client.models.list()
    model_name: str = available_models.data[0].id

    user_turn = {
        "role": "user",
        "content": "What is the weather in Dallas, Texas?",
    }

    # This exact combination of options used to crash the engine.
    completion = await client.chat.completions.create(
        messages=ensure_system_prompt([user_turn], server_config),
        temperature=0,
        max_completion_tokens=150,
        model=model_name,
        tools=[WEATHER_TOOL],
        tool_choice="required",
        response_format={"type": "json_object"},
    )

    # The fix clears response_format when tool_choice forces tool calling,
    # so the request must finish normally and carry at least one tool call.
    first_choice = completion.choices[0]
    assert first_choice.finish_reason == "tool_calls"

    returned_calls = first_choice.message.tool_calls
    assert returned_calls is not None
    assert len(returned_calls) > 0
tests/tool_use/test_minimax_m2_tool_parser.py
0 → 100644
View file @
7e63ef82
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
pytest
from
vllm.tool_parsers.minimax_m2_tool_parser
import
(
MinimaxM2ToolParser
,
)
# Module-level pytest marker: tags every test in this file with `cpu_test`.
pytestmark = pytest.mark.cpu_test
class FakeTokenizer:
    """Stand-in tokenizer carrying only what the tool parser reads.

    It provides a truthy ``model_tokenizer`` marker and a ``vocab`` mapping
    that contains the two special tool-call tokens.
    """

    def __init__(self):
        # Only truthiness matters for this marker attribute.
        self.model_tokenizer = True

        # The parser looks up the start/end tokens by their literal strings;
        # they are assigned the ids 1 and 2 in that order.
        special_tokens = ("<minimax:tool_call>", "</minimax:tool_call>")
        self.vocab = {
            token: token_id
            for token_id, token in enumerate(special_tokens, start=1)
        }

    def get_vocab(self):
        """Return the vocab mapping itself (the same object, not a copy)."""
        return self.vocab
@pytest.fixture
def minimax_m2_tool_parser():
    """Provide a MinimaxM2ToolParser built on top of a FakeTokenizer."""
    tokenizer = FakeTokenizer()
    return MinimaxM2ToolParser(tokenizer)
def test_extract_tool_calls_streaming_incremental(minimax_m2_tool_parser):
    """Stream one tool call chunk by chunk and inspect the recorded call.

    After all chunks are fed, the parser's ``prev_tool_call_arr`` must hold
    exactly one entry named ``get_weather`` whose ``city`` argument is
    ``Seattle``.
    """
    parser = minimax_m2_tool_parser
    parser._reset_streaming_state()

    pieces = [
        "<minimax:tool_call>",
        '<invoke name="get_weather">',
        '<parameter name="city">',
        "Seattle</parameter>",
        "</invoke></minimax:tool_call>",
    ]

    # Feed the pieces one at a time; only the text arguments carry data,
    # the token-id lists are always passed empty.
    text_so_far = ""
    for piece in pieces:
        extended_text = text_so_far + piece
        parser.extract_tool_calls_streaming(
            previous_text=text_so_far,
            current_text=extended_text,
            delta_text=piece,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )
        text_so_far = extended_text

    assert len(parser.prev_tool_call_arr) == 1

    recorded_call = parser.prev_tool_call_arr[0]
    assert recorded_call["name"] == "get_weather"

    call_arguments = recorded_call["arguments"]
    assert call_arguments["city"] == "Seattle"
def test_streaming_minimax_m2_multiple_invokes(minimax_m2_tool_parser):
    """Stream two ``search_web`` invokes inside one tool-call section.

    Checks that both calls are recorded in ``prev_tool_call_arr`` with the
    expected arguments, and that ``streamed_args_for_tool`` holds the JSON
    dump of each call's arguments.
    """
    parser = minimax_m2_tool_parser
    parser._reset_streaming_state()

    pieces = [
        "<minimax:tool_call>",
        '<invoke name="search_web">',
        '<parameter name="query_tag">',
        '["technology", "events"]</parameter>',
        '<parameter name="query_list">',
        '["OpenAI", "latest", "release"]</parameter>',
        "</invoke>",
        '<invoke name="search_web">',
        '<parameter name="query_tag">',
        '["technology", "events"]</parameter>',
        '<parameter name="query_list">',
        '["Gemini", "latest", "release"]</parameter>',
        "</invoke>",
        "</minimax:tool_call>",
    ]

    # Feed the pieces one at a time; the token-id lists are always empty.
    text_so_far = ""
    for piece in pieces:
        extended_text = text_so_far + piece
        parser.extract_tool_calls_streaming(
            previous_text=text_so_far,
            current_text=extended_text,
            delta_text=piece,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )
        text_so_far = extended_text

    assert len(parser.prev_tool_call_arr) == 2

    # Each invoke shares the tags but carries its own query subject.
    expected_subjects = ["OpenAI", "Gemini"]
    for recorded_call, subject in zip(parser.prev_tool_call_arr, expected_subjects):
        assert recorded_call["name"] == "search_web"
        serialized_args = json.dumps(recorded_call["arguments"])
        assert "technology" in serialized_args and "events" in serialized_args
        assert subject in serialized_args

    # check streamed_args_for_tool for serving_chat.py
    for index in range(2):
        expected_call = json.dumps(
            parser.prev_tool_call_arr[index].get("arguments", {})
        )
        actual_call = parser.streamed_args_for_tool[index]
        assert expected_call == actual_call
tests/tool_use/test_tool_choice_required.py
View file @
7e63ef82
...
@@ -311,6 +311,7 @@ def test_streaming_output_valid(output, empty_params, delta_len):
...
@@ -311,6 +311,7 @@ def test_streaming_output_valid(output, empty_params, delta_len):
previous_text
=
current_text
previous_text
=
current_text
assert
len
(
messages
)
>
0
assert
len
(
messages
)
>
0
combined_messages
=
"["
combined_messages
=
"["
for
message
in
messages
:
for
message
in
messages
:
if
message
.
tool_calls
[
0
].
function
.
name
:
if
message
.
tool_calls
[
0
].
function
.
name
:
...
@@ -328,3 +329,35 @@ def test_streaming_output_valid(output, empty_params, delta_len):
...
@@ -328,3 +329,35 @@ def test_streaming_output_valid(output, empty_params, delta_len):
combined_messages
+=
"}]"
combined_messages
+=
"}]"
assert
json
.
loads
(
combined_messages
)
==
output
assert
json
.
loads
(
combined_messages
)
==
output
assert
json
.
dumps
(
json
.
loads
(
combined_messages
))
==
output_json
assert
json
.
dumps
(
json
.
loads
(
combined_messages
))
==
output_json
def test_streaming_output_valid_with_trailing_extra_data():
    """Stream a tool-call JSON payload that is followed by extra text.

    The serialized tool call has "\\nDONE" appended and is pushed through
    ``OpenAIServingChat.extract_tool_call_required_streaming`` in
    3-character deltas. The test passes when at least one delta message
    is produced.
    """
    # The method is called unbound, with a MagicMock in place of `self`.
    serving_stub = MagicMock()

    tool_calls = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
    payload = json.dumps(tool_calls) + "\nDONE"

    chunk_size = 3
    collected = []
    name_already_sent = False
    text_so_far = ""

    for start in range(0, len(payload), chunk_size):
        piece = payload[start : start + chunk_size]
        extended_text = text_so_far + piece

        result = OpenAIServingChat.extract_tool_call_required_streaming(
            serving_stub,
            previous_text=text_so_far,
            current_text=extended_text,
            delta_text=piece,
            function_name_returned=name_already_sent,
        )
        # The call returns the delta (possibly falsy) and the updated flag.
        delta_message, name_already_sent = result

        if delta_message:
            collected.append(delta_message)

        text_so_far = extended_text

    assert len(collected) > 0
tests/utils.py
View file @
7e63ef82
...
@@ -112,6 +112,7 @@ class RemoteOpenAIServer:
...
@@ -112,6 +112,7 @@ class RemoteOpenAIServer:
env
.
update
(
env_dict
)
env
.
update
(
env_dict
)
serve_cmd
=
[
"vllm"
,
"serve"
,
model
,
*
vllm_serve_args
]
serve_cmd
=
[
"vllm"
,
"serve"
,
model
,
*
vllm_serve_args
]
print
(
f
"Launching RemoteOpenAIServer with:
{
' '
.
join
(
serve_cmd
)
}
"
)
print
(
f
"Launching RemoteOpenAIServer with:
{
' '
.
join
(
serve_cmd
)
}
"
)
print
(
f
"Environment variables:
{
env
}
"
)
self
.
proc
:
subprocess
.
Popen
=
subprocess
.
Popen
(
self
.
proc
:
subprocess
.
Popen
=
subprocess
.
Popen
(
serve_cmd
,
serve_cmd
,
env
=
env
,
env
=
env
,
...
@@ -726,13 +727,34 @@ def init_test_distributed_environment(
...
@@ -726,13 +727,34 @@ def init_test_distributed_environment(
distributed_init_port
:
str
,
distributed_init_port
:
str
,
local_rank
:
int
=
-
1
,
local_rank
:
int
=
-
1
,
)
->
None
:
)
->
None
:
distributed_init_method
=
f
"tcp://localhost:
{
distributed_init_port
}
"
# Note: This function is often called from Ray worker processes, so we
init_distributed_environment
(
# can't rely on pytest fixtures to set the config. We check if the config
world_size
=
pp_size
*
tp_size
,
# is already set and only create a default one if needed.
rank
=
rank
,
from
vllm.config
import
(
distributed_init_method
=
distributed_init_method
,
VllmConfig
,
local_rank
=
local_rank
,
get_current_vllm_config_or_none
,
set_current_vllm_config
,
)
)
distributed_init_method
=
f
"tcp://localhost:
{
distributed_init_port
}
"
if
get_current_vllm_config_or_none
()
is
not
None
:
# Config already set, use it directly
init_distributed_environment
(
world_size
=
pp_size
*
tp_size
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
local_rank
=
local_rank
,
)
else
:
# No config set, create a default one for the test
with
set_current_vllm_config
(
VllmConfig
()):
init_distributed_environment
(
world_size
=
pp_size
*
tp_size
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
local_rank
=
local_rank
,
)
ensure_model_parallel_initialized
(
tp_size
,
pp_size
)
ensure_model_parallel_initialized
(
tp_size
,
pp_size
)
...
...
tests/utils_/test_torch_utils.py
View file @
7e63ef82
...
@@ -99,30 +99,18 @@ def _test_stream_thread(main_expected_stream: torch.cuda.Stream):
...
@@ -99,30 +99,18 @@ def _test_stream_thread(main_expected_stream: torch.cuda.Stream):
def
test_current_stream_multithread
():
def
test_current_stream_multithread
():
from
vllm.platforms
import
current_platform
if
not
torch
.
cuda
.
is_available
():
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
(
"CUDA not available"
)
pytest
.
skip
(
"CUDA not available"
)
if
current_platform
.
is_rocm
():
main_dedicated_stream
=
current_stream
()
main_dedicated_stream
=
current_stream
()
assert
main_dedicated_stream
.
cuda_stream
!=
0
,
(
"ROCm should create a dedicated stream, not use default stream (0x0)"
)
main_stream_again
=
current_stream
()
assert
main_stream_again
==
main_dedicated_stream
,
(
"Multiple calls to current_stream should return the same dedicated stream"
)
_test_stream_thread
(
main_dedicated_stream
)
assert
main_dedicated_stream
.
cuda_stream
!=
0
,
(
else
:
"ROCm/CUDA should create a dedicated stream, not use default stream (0x0)"
main_default_stream
=
torch
.
cuda
.
default_stream
()
)
main_initial_stream
=
current_stream
()
assert
main_initial_stream
==
main_default_stream
,
(
main_stream_again
=
current_stream
()
"First call to current_stream should return default stream on CUDA"
assert
main_stream_again
==
main_dedicated_stream
,
(
)
"Multiple calls to current_stream should return the same dedicated stream"
)
_test_stream_thread
(
main_de
fault
_stream
)
_test_stream_thread
(
main_de
dicated
_stream
)
tests/v1/attention/test_attention_backends.py
View file @
7e63ef82
...
@@ -15,13 +15,17 @@ from tests.v1.attention.utils import (
...
@@ -15,13 +15,17 @@ from tests.v1.attention.utils import (
create_vllm_config
,
create_vllm_config
,
try_get_attention_backend
,
try_get_attention_backend
,
)
)
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
is_torch_equal_or_newer
from
vllm.utils.torch_utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
is_torch_equal_or_newer
,
set_random_seed
,
)
from
vllm.v1.attention.backend
import
AttentionType
,
CommonAttentionMetadata
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
set_kv_cache_layout
,
set_kv_cache_layout
,
)
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
...
@@ -79,6 +83,13 @@ BATCH_SPECS = {
...
@@ -79,6 +83,13 @@ BATCH_SPECS = {
),
),
"single_decode"
:
BatchSpec
(
seq_lens
=
[
1024
],
query_lens
=
[
1
]),
"single_decode"
:
BatchSpec
(
seq_lens
=
[
1024
],
query_lens
=
[
1
]),
"single_prefill"
:
BatchSpec
(
seq_lens
=
[
1024
],
query_lens
=
[
64
]),
"single_prefill"
:
BatchSpec
(
seq_lens
=
[
1024
],
query_lens
=
[
64
]),
# encoder-only
"small_encoder_prefill"
:
BatchSpec
(
seq_lens
=
[
32
,
64
,
128
,
256
],
query_lens
=
[
32
,
64
,
128
,
256
]
),
"medium_encoder_prefill"
:
BatchSpec
(
seq_lens
=
[
256
,
512
,
1024
,
2048
],
query_lens
=
[
256
,
512
,
1024
,
2048
]
),
}
}
...
@@ -114,17 +125,17 @@ def create_and_prepopulate_kv_cache(
...
@@ -114,17 +125,17 @@ def create_and_prepopulate_kv_cache(
Tuple of (kv_cache, updated_block_table)
Tuple of (kv_cache, updated_block_table)
"""
"""
batch_size
=
len
(
k_contexts
)
batch_size
=
len
(
k_contexts
)
seq_lens
=
common_attn_metadata
.
seq_lens
_
cpu
seq_lens
=
common_attn_metadata
.
seq_lens
.
cpu
()
query_lens
=
(
query_lens
=
(
common_attn_metadata
.
query_start_loc_cpu
[
1
:]
common_attn_metadata
.
query_start_loc_cpu
[
1
:]
-
common_attn_metadata
.
query_start_loc_cpu
[:
-
1
]
-
common_attn_metadata
.
query_start_loc_cpu
[:
-
1
]
)
)
context_lens
=
common_attn_metadata
.
num_computed_tokens_cpu
context_lens
=
seq_lens
-
query_lens
block_table
=
common_attn_metadata
.
block_table_tensor
block_table
=
common_attn_metadata
.
block_table_tensor
slot_mapping
=
common_attn_metadata
.
slot_mapping
slot_mapping
=
common_attn_metadata
.
slot_mapping
# Create KV cache
# Create KV cache
kv_cache
=
torch
.
empty
(
kv_cache
=
torch
.
zeros
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
,
device
=
device
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
,
device
=
device
)
)
kv_cache_flat
=
kv_cache
.
view
(
2
,
-
1
,
num_kv_heads
,
head_size
)
kv_cache_flat
=
kv_cache
.
view
(
2
,
-
1
,
num_kv_heads
,
head_size
)
...
@@ -205,6 +216,7 @@ def run_attention_backend(
...
@@ -205,6 +216,7 @@ def run_attention_backend(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
sliding_window
:
int
|
None
=
None
,
sliding_window
:
int
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Run attention computation using the specified backend's AttentionImpl."""
"""Run attention computation using the specified backend's AttentionImpl."""
...
@@ -272,6 +284,7 @@ def run_attention_backend(
...
@@ -272,6 +284,7 @@ def run_attention_backend(
num_kv_heads
=
num_kv_heads
,
num_kv_heads
=
num_kv_heads
,
alibi_slopes
=
None
,
alibi_slopes
=
None
,
sliding_window
=
sliding_window
,
sliding_window
=
sliding_window
,
attn_type
=
attn_type
,
kv_cache_dtype
=
"auto"
,
kv_cache_dtype
=
"auto"
,
)
)
...
@@ -295,6 +308,7 @@ def _test_backend_correctness(
...
@@ -295,6 +308,7 @@ def _test_backend_correctness(
backend_to_test
:
list
[
AttentionBackendEnum
|
str
],
backend_to_test
:
list
[
AttentionBackendEnum
|
str
],
mask_mod
,
mask_mod
,
*
,
*
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
block_size
:
int
=
16
,
block_size
:
int
=
16
,
atol
:
float
=
1e-2
,
atol
:
float
=
1e-2
,
rtol
:
float
=
1e-2
,
rtol
:
float
=
1e-2
,
...
@@ -320,7 +334,7 @@ def _test_backend_correctness(
...
@@ -320,7 +334,7 @@ def _test_backend_correctness(
multiple GPUs. This tests that backends work correctly with different
multiple GPUs. This tests that backends work correctly with different
head counts.
head counts.
"""
"""
current_platform
.
seed_everything
(
42
)
set_random_seed
(
42
)
hf_config_override
=
None
hf_config_override
=
None
if
tensor_parallel_size
>
1
:
if
tensor_parallel_size
>
1
:
...
@@ -432,6 +446,9 @@ def _test_backend_correctness(
...
@@ -432,6 +446,9 @@ def _test_backend_correctness(
common_attn_metadata
=
create_common_attn_metadata
(
common_attn_metadata
=
create_common_attn_metadata
(
batch_spec
,
vllm_config
.
cache_config
.
block_size
,
device
batch_spec
,
vllm_config
.
cache_config
.
block_size
,
device
)
)
if
attn_type
==
AttentionType
.
ENCODER_ONLY
:
# For encoder-only, all tokens are prefill tokens
common_attn_metadata
.
causal
=
False
# 3. Simulate Paged KV Cache and a realistic slot_mapping
# 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache
=
create_and_prepopulate_kv_cache
(
kv_cache
=
create_and_prepopulate_kv_cache
(
...
@@ -487,6 +504,7 @@ def _test_backend_correctness(
...
@@ -487,6 +504,7 @@ def _test_backend_correctness(
value_vllm
,
value_vllm
,
kv_cache_for_backend
,
kv_cache_for_backend
,
sliding_window
=
sliding_window
,
sliding_window
=
sliding_window
,
attn_type
=
attn_type
,
)
)
finally
:
finally
:
if
reset_kv_cache_layout
:
if
reset_kv_cache_layout
:
...
@@ -537,7 +555,7 @@ def _test_backend_correctness(
...
@@ -537,7 +555,7 @@ def _test_backend_correctness(
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"meta-llama/Meta-Llama-3-8B"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"meta-llama/Meta-Llama-3-8B"
])
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
[
1
,
2
,
4
])
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
[
1
,
2
,
4
])
def
test_causal_backend_correctness
(
def
test_causal_backend_correctness
(
batch_spec_name
:
str
,
model
:
str
,
tensor_parallel_size
:
int
default_vllm_config
,
batch_spec_name
:
str
,
model
:
str
,
tensor_parallel_size
:
int
):
):
"""Test backend's correctness with causal attention."""
"""Test backend's correctness with causal attention."""
...
@@ -557,9 +575,21 @@ def test_causal_backend_correctness(
...
@@ -557,9 +575,21 @@ def test_causal_backend_correctness(
if
is_torch_equal_or_newer
(
"2.9.0.dev0"
)
if
is_torch_equal_or_newer
(
"2.9.0.dev0"
)
else
[]
else
[]
)
)
SMALL_BLOCK_BACKENDS
=
[
x
for
x
in
BACKENDS_TO_TEST
if
x
not
in
LARGE_BLOCK_BACKENDS
if
current_platform
.
is_rocm
():
]
SMALL_BLOCK_BACKENDS
=
[
x
for
x
in
BACKENDS_TO_TEST
if
(
x
not
in
LARGE_BLOCK_BACKENDS
and
x
is
not
AttentionBackendEnum
.
FLASH_ATTN
)
]
else
:
SMALL_BLOCK_BACKENDS
=
[
x
for
x
in
BACKENDS_TO_TEST
if
x
not
in
LARGE_BLOCK_BACKENDS
]
_test_backend_correctness
(
_test_backend_correctness
(
batch_spec
,
batch_spec
,
model
,
model
,
...
@@ -580,12 +610,20 @@ def test_causal_backend_correctness(
...
@@ -580,12 +610,20 @@ def test_causal_backend_correctness(
)
)
SLIDING_WINDOW_BACKENDS_TO_TEST
=
[
if
current_platform
.
is_rocm
():
AttentionBackendEnum
.
FLASH_ATTN
,
# FLASH_ATTN is not supported on ROCm
AttentionBackendEnum
.
FLEX_ATTENTION
,
SLIDING_WINDOW_BACKENDS_TO_TEST
=
[
AttentionBackendEnum
.
TRITON_ATTN
,
AttentionBackendEnum
.
FLEX_ATTENTION
,
"FLEX_ATTENTION_SLOW"
,
AttentionBackendEnum
.
TRITON_ATTN
,
]
"FLEX_ATTENTION_SLOW"
,
]
else
:
SLIDING_WINDOW_BACKENDS_TO_TEST
=
[
AttentionBackendEnum
.
FLASH_ATTN
,
AttentionBackendEnum
.
FLEX_ATTENTION
,
AttentionBackendEnum
.
TRITON_ATTN
,
"FLEX_ATTENTION_SLOW"
,
]
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -652,3 +690,45 @@ def test_sliding_window_backend_correctness(
...
@@ -652,3 +690,45 @@ def test_sliding_window_backend_correctness(
block_size
=
128
,
block_size
=
128
,
tensor_parallel_size
=
tensor_parallel_size
,
tensor_parallel_size
=
tensor_parallel_size
,
)
)
@pytest.mark.parametrize(
    "batch_spec_name",
    ["small_encoder_prefill", "medium_encoder_prefill"],
)
@pytest.mark.parametrize("model", ["google/embeddinggemma-300m"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_sliding_window_encoder_backend_correctness(
    batch_spec_name: str, model: str, tensor_parallel_size: int
):
    """Test backend correctness for encoder-only sliding window attention.

    Runs the sliding-window backend list through ``_test_backend_correctness``
    with ``AttentionType.ENCODER_ONLY`` and a mask that is symmetric around
    the query position.
    """

    def bidi_sliding_window_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
        *,
        context_len: int,
        sliding_window: int,
    ):
        # Keep a key when its absolute distance from the (context-shifted)
        # query position is strictly below the window size.
        # NOTE(review): `context_len` is not bound here; presumably
        # `_test_backend_correctness` supplies it per sequence - confirm.
        distance = torch.abs(q_idx + context_len - kv_idx)
        return distance < sliding_window

    spec = BATCH_SPECS[batch_spec_name]

    # The window size is read from the model's own configuration.
    longest_seq = max(spec.seq_lens)
    window = ModelConfig(model=model, max_model_len=longest_seq).get_sliding_window()
    mask_mod = partial(bidi_sliding_window_mask_mod, sliding_window=window)

    _test_backend_correctness(
        spec,
        model,
        SLIDING_WINDOW_BACKENDS_TO_TEST,
        mask_mod,
        attn_type=AttentionType.ENCODER_ONLY,
        tensor_parallel_size=tensor_parallel_size,
    )
tests/v1/attention/test_attention_backends_selection.py
View file @
7e63ef82
...
@@ -79,7 +79,12 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
...
@@ -79,7 +79,12 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
],
],
)
)
def
test_mamba_layers_get_attn_backend
(
def
test_mamba_layers_get_attn_backend
(
dist_init
,
layer_class
,
init_kwargs
,
expected_backend
,
expected_mamba_type
default_vllm_config
,
dist_init
,
layer_class
,
init_kwargs
,
expected_backend
,
expected_mamba_type
,
):
):
"""Test that Mamba-like layers return the correct attention backend."""
"""Test that Mamba-like layers return the correct attention backend."""
layer
=
layer_class
(
**
init_kwargs
)
layer
=
layer_class
(
**
init_kwargs
)
...
...
Prev
1
…
26
27
28
29
30
31
32
33
34
35
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment