Commit 7e63ef82 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0' into v0.14.0-dev

parents 8cbcac5d b17039bc
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for AITER MLA FP8 support detection.
These tests verify that the _check_aiter_mla_fp8_support() function
correctly handles various error conditions without crashing.
"""
from unittest.mock import patch
import pytest
class TestAiterMlaFp8SupportCheck:
"""Test cases for _check_aiter_mla_fp8_support() function."""
def setup_method(self):
"""Reset the global cache before each test."""
import vllm._aiter_ops as aiter_ops
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_import_error_handling(self, mock_supported):
"""Test that ImportError is handled gracefully."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
# Should return False without raising
with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=ImportError("No module"),
):
result = _check_aiter_mla_fp8_support()
assert result is False
@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_module_not_found_error_handling(self, mock_supported):
"""Test that ModuleNotFoundError is handled gracefully."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=ModuleNotFoundError("Module not found"),
):
# Should return False without raising
assert _check_aiter_mla_fp8_support() is False
# Cache should be set to False
assert aiter_ops._AITER_MLA_SUPPORTS_FP8 is False
@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_attribute_error_handling(self, mock_supported):
"""Test that AttributeError is handled gracefully."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=AttributeError("No attribute"),
):
assert _check_aiter_mla_fp8_support() is False
assert aiter_ops._AITER_MLA_SUPPORTS_FP8 is False
@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_value_error_handling(self, mock_supported):
"""Test that ValueError is handled gracefully (no signature)."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=ValueError("No signature"),
):
assert _check_aiter_mla_fp8_support() is False
assert aiter_ops._AITER_MLA_SUPPORTS_FP8 is False
@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_type_error_handling(self, mock_supported):
"""Test that TypeError is handled gracefully (not callable)."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=TypeError("Not a callable"),
):
assert _check_aiter_mla_fp8_support() is False
assert aiter_ops._AITER_MLA_SUPPORTS_FP8 is False
@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_result_caching(self, mock_supported):
"""Test that the result is cached after first check."""
import vllm._aiter_ops as aiter_ops
# Set cache to True
aiter_ops._AITER_MLA_SUPPORTS_FP8 = True
from vllm._aiter_ops import _check_aiter_mla_fp8_support
# Should return cached value without re-checking
result = _check_aiter_mla_fp8_support()
assert result is True
if __name__ == "__main__":
pytest.main([__file__, "-v"])
...@@ -5,9 +5,6 @@ ...@@ -5,9 +5,6 @@
# The utility function cannot be placed in `vllm.utils` # The utility function cannot be placed in `vllm.utils`
# this needs to be a standalone script # this needs to be a standalone script
import sys import sys
from contextlib import nullcontext
from vllm_test_utils import BlameResult, blame
# List of modules that should not be imported too early. # List of modules that should not be imported too early.
# Lazy import `torch._inductor.async_compile` to avoid creating # Lazy import `torch._inductor.async_compile` to avoid creating
...@@ -16,26 +13,10 @@ from vllm_test_utils import BlameResult, blame ...@@ -16,26 +13,10 @@ from vllm_test_utils import BlameResult, blame
# `cv2` can easily mess up the environment. # `cv2` can easily mess up the environment.
module_names = ["torch._inductor.async_compile", "cv2"] module_names = ["torch._inductor.async_compile", "cv2"]
# set all modules in `module_names` to be None.
# if we import any modules during `import vllm`, there would be a
# hard error and nice stacktrace on the first import.
for module_name in module_names:
sys.modules[module_name] = None # type: ignore[assignment]
def any_module_imported(): import vllm # noqa
return any(module_name in sys.modules for module_name in module_names)
# In CI, we only check finally if the module is imported.
# If it is indeed imported, we can rerun the test with `use_blame=True`,
# which will trace every function call to find the first import location,
# and help find the root cause.
# We don't run it in CI by default because it is slow.
use_blame = False
context = blame(any_module_imported) if use_blame else nullcontext()
with context as result:
import vllm # noqa
if use_blame:
assert isinstance(result, BlameResult)
print(f"the first import location is:\n{result.trace_stack}")
assert not any_module_imported(), (
f"Some the modules in {module_names} are imported. To see the first"
f" import location, run the test with `use_blame=True`."
)
...@@ -18,25 +18,37 @@ for i in {1..5}; do ...@@ -18,25 +18,37 @@ for i in {1..5}; do
echo "Checking metadata.json URL (attempt $i)..." echo "Checking metadata.json URL (attempt $i)..."
if curl --fail "$meta_json_url" > metadata.json; then if curl --fail "$meta_json_url" > metadata.json; then
echo "INFO: metadata.json URL is valid." echo "INFO: metadata.json URL is valid."
# check whether it is valid json by python # check whether it is valid json by python (printed to stdout)
if python3 -m json.tool metadata.json; then if python3 -m json.tool metadata.json; then
echo "INFO: metadata.json is valid JSON. Proceeding with the test." echo "INFO: metadata.json is valid JSON. Proceeding with the check."
# check whether there is an object in the json matching:
# "package_name": "vllm", and "platform_tag" matches the current architecture
# see `determine_wheel_url` in setup.py for more details
if python3 -c "import platform as p,json as j,sys as s; d = j.load(open('metadata.json')); \
s.exit(int(not any(o.get('package_name') == 'vllm' and p.machine() in o.get('platform_tag') \
for o in d)))" 2>/dev/null; then
echo "INFO: metadata.json contains a pre-compiled wheel for the current architecture."
break
else
echo "WARN: metadata.json does not have a pre-compiled wheel for the current architecture."
fi
else else
echo "CRITICAL: metadata.json exists but is not valid JSON, please do report in #sig-ci channel!" echo "CRITICAL: metadata.json exists but is not valid JSON, please do report in #sig-ci channel!"
echo "INFO: metadata.json content:"
cat metadata.json
exit 1 exit 1
fi fi
break
fi fi
# failure handling # failure handling & retry logic
if [ $i -eq 5 ]; then if [ $i -eq 5 ]; then
echo "ERROR: metadata.json URL is still not valid after 5 attempts." echo "ERROR: metadata is still not available after 5 attempts."
echo "ERROR: Please check whether the precompiled wheel for commit $merge_base_commit exists." echo "ERROR: Please check whether the precompiled wheel for commit $merge_base_commit is available."
echo " NOTE: If $merge_base_commit is a new commit on main, maybe try again after its release pipeline finishes." echo " NOTE: If $merge_base_commit is a new commit on main, maybe try again after its release pipeline finishes."
echo " NOTE: If it fails, please report in #sig-ci channel." echo " NOTE: If it fails, please report in #sig-ci channel."
exit 1 exit 1
else else
echo "WARNING: metadata.json URL is not valid. Retrying in 3 minutes..." echo "WARNING: metadata is not available. Retrying after 5 minutes..."
sleep 180 sleep 300
fi fi
done done
......
...@@ -4,6 +4,11 @@ ...@@ -4,6 +4,11 @@
set -e set -e
set -x set -x
if command -v rocminfo >/dev/null 2>&1; then
echo "Skipping test for ROCm platform"
exit 0
fi
cd /vllm-workspace/ cd /vllm-workspace/
rm -rf .venv rm -rf .venv
...@@ -36,7 +41,7 @@ if diff before.txt after.txt; then ...@@ -36,7 +41,7 @@ if diff before.txt after.txt; then
echo "torch version not overridden." echo "torch version not overridden."
else else
echo "torch version overridden by nightly_torch_test.txt, \ echo "torch version overridden by nightly_torch_test.txt, \
if the dependency is not triggered by the pytroch nightly test,\ if the dependency is not triggered by the pytorch nightly test,\
please add the dependency to the list 'white_list' in tools/pre_commit/generate_nightly_torch_test.py" please add the dependency to the list 'white_list' in tools/pre_commit/generate_nightly_torch_test.py"
exit 1 exit 1
fi fi
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
)
from vllm.v1.attention.backends.registry import (
AttentionBackendEnum,
MambaAttentionBackendEnum,
register_backend,
)
class CustomAttentionImpl(AttentionImpl):
"""Mock custom attention implementation for testing."""
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, *args, **kwargs):
"""Mock forward pass."""
pass
class CustomAttentionBackend(AttentionBackend):
"""Mock custom attention backend for testing."""
@staticmethod
def get_name():
return "CUSTOM"
@staticmethod
def get_impl_cls():
return CustomAttentionImpl
@staticmethod
def get_builder_cls():
"""Mock builder class."""
return None
@staticmethod
def get_required_kv_cache_layout():
"""Mock KV cache layout."""
return None
class CustomMambaAttentionImpl(AttentionImpl):
"""Mock custom mamba attention implementation for testing."""
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, *args, **kwargs):
"""Mock forward pass."""
pass
class CustomMambaAttentionBackend(AttentionBackend):
"""Mock custom mamba attention backend for testing."""
@staticmethod
def get_name():
return "CUSTOM_MAMBA"
@staticmethod
def get_impl_cls():
return CustomMambaAttentionImpl
@staticmethod
def get_builder_cls():
"""Mock builder class."""
return None
@staticmethod
def get_required_kv_cache_layout():
"""Mock KV cache layout."""
return None
def test_custom_is_not_alias_of_any_backend():
# Get all members of AttentionBackendEnum
all_backends = list(AttentionBackendEnum)
# Find any aliases of CUSTOM
aliases = []
for backend in all_backends:
if backend.name != "CUSTOM" and backend is AttentionBackendEnum.CUSTOM:
aliases.append(backend.name)
# CUSTOM should not be an alias of any other backend
assert len(aliases) == 0, (
f"BUG! CUSTOM is an alias of: {', '.join(aliases)}!\n"
f"CUSTOM.value = {repr(AttentionBackendEnum.CUSTOM.value)}\n"
f"This happens when CUSTOM has the same value as another backend.\n"
f"When you register to CUSTOM, you're actually registering to {aliases[0]}!\n"
f"All backend values:\n"
+ "\n".join(f" {b.name}: {repr(b.value)}" for b in all_backends)
)
# Verify CUSTOM has its own unique identity
assert AttentionBackendEnum.CUSTOM.name == "CUSTOM", (
f"CUSTOM.name should be 'CUSTOM', but got '{AttentionBackendEnum.CUSTOM.name}'"
)
def test_register_custom_backend_with_class_path():
# Register with explicit class path
register_backend(
backend=AttentionBackendEnum.CUSTOM,
class_path="tests.test_attention_backend_registry.CustomAttentionBackend",
is_mamba=False,
)
# Check that CUSTOM backend is registered
assert AttentionBackendEnum.CUSTOM.is_overridden(), (
"CUSTOM should be overridden after registration"
)
# Get the registered class path
class_path = AttentionBackendEnum.CUSTOM.get_path()
assert class_path == "tests.test_attention_backend_registry.CustomAttentionBackend"
# Get the backend class
backend_cls = AttentionBackendEnum.CUSTOM.get_class()
assert backend_cls.get_name() == "CUSTOM"
assert backend_cls.get_impl_cls() == CustomAttentionImpl
def test_mamba_custom_is_not_alias_of_any_backend():
# Get all mamba backends
all_backends = list(MambaAttentionBackendEnum)
# Find any aliases of CUSTOM
aliases = []
for backend in all_backends:
if backend.name != "CUSTOM" and backend is MambaAttentionBackendEnum.CUSTOM:
aliases.append(backend.name)
# CUSTOM should not be an alias of any other backend
assert len(aliases) == 0, (
f"BUG! MambaAttentionBackendEnum.CUSTOM is an alias of: {', '.join(aliases)}!\n"
f"CUSTOM.value = {repr(MambaAttentionBackendEnum.CUSTOM.value)}\n"
f"All mamba backend values:\n"
+ "\n".join(f" {b.name}: {repr(b.value)}" for b in all_backends)
)
def test_register_custom_mamba_backend_with_class_path():
# Register with explicit class path
register_backend(
backend=MambaAttentionBackendEnum.CUSTOM,
class_path="tests.test_attention_backend_registry.CustomMambaAttentionBackend",
is_mamba=True,
)
# Check that the backend is registered
assert MambaAttentionBackendEnum.CUSTOM.is_overridden()
# Get the registered class path
class_path = MambaAttentionBackendEnum.CUSTOM.get_path()
assert (
class_path
== "tests.test_attention_backend_registry.CustomMambaAttentionBackend"
)
# Get the backend class
backend_cls = MambaAttentionBackendEnum.CUSTOM.get_class()
assert backend_cls.get_name() == "CUSTOM_MAMBA"
assert backend_cls.get_impl_cls() == CustomMambaAttentionImpl
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging import logging
import os import os
from dataclasses import MISSING, Field, asdict, dataclass, field from dataclasses import MISSING, Field, asdict, dataclass, field
...@@ -25,7 +26,6 @@ from vllm.config.vllm import ( ...@@ -25,7 +26,6 @@ from vllm.config.vllm import (
OPTIMIZATION_LEVEL_TO_CONFIG, OPTIMIZATION_LEVEL_TO_CONFIG,
OptimizationLevel, OptimizationLevel,
) )
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform from vllm.platforms import current_platform
from utils import models_path_prefix from utils import models_path_prefix
...@@ -162,8 +162,9 @@ def test_get_pooling_config(): ...@@ -162,8 +162,9 @@ def test_get_pooling_config():
model_config = ModelConfig(model_id) model_config = ModelConfig(model_id)
assert model_config.pooler_config is not None assert model_config.pooler_config is not None
assert model_config.pooler_config.normalize assert model_config.pooler_config.use_activation
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name assert model_config.pooler_config.seq_pooling_type == "MEAN"
assert model_config.pooler_config.tok_pooling_type == "ALL"
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -171,7 +172,7 @@ def test_get_pooling_config(): ...@@ -171,7 +172,7 @@ def test_get_pooling_config():
) )
def test_get_pooling_config_from_args(): def test_get_pooling_config_from_args():
model_id = os.path.join(models_path_prefix, "sentence-transformers/all-MiniLM-L12-v2") model_id = os.path.join(models_path_prefix, "sentence-transformers/all-MiniLM-L12-v2")
pooler_config = PoolerConfig(pooling_type="CLS", normalize=True) pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
model_config = ModelConfig(model_id, pooler_config=pooler_config) model_config = ModelConfig(model_id, pooler_config=pooler_config)
assert asdict(model_config.pooler_config) == asdict(pooler_config) assert asdict(model_config.pooler_config) == asdict(pooler_config)
...@@ -182,14 +183,25 @@ def test_get_pooling_config_from_args(): ...@@ -182,14 +183,25 @@ def test_get_pooling_config_from_args():
[ [
("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM
("intfloat/e5-small", "CLS", "MEAN"), # BertModel ("intfloat/e5-small", "CLS", "MEAN"), # BertModel
],
)
def test_default_seq_pooling_type(model_id, default_pooling_type, pooling_type):
model_config = ModelConfig(model_id)
assert model_config._model_info.default_seq_pooling_type == default_pooling_type
assert model_config.pooler_config.seq_pooling_type == pooling_type
@pytest.mark.parametrize(
("model_id", "default_pooling_type", "pooling_type"),
[
("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward ("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward
("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"), # step reward ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"), # step reward
], ],
) )
def test_default_pooling_type(model_id, default_pooling_type, pooling_type): def test_default_tok_pooling_type(model_id, default_pooling_type, pooling_type):
model_config = ModelConfig(model_id) model_config = ModelConfig(model_id)
assert model_config._model_info.default_pooling_type == default_pooling_type assert model_config._model_info.default_tok_pooling_type == default_pooling_type
assert model_config.pooler_config.pooling_type == pooling_type assert model_config.pooler_config.tok_pooling_type == pooling_type
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -207,8 +219,8 @@ def test_default_pooling_type(model_id, default_pooling_type, pooling_type): ...@@ -207,8 +219,8 @@ def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
) )
def test_moe_model_detection(model_id, expected_is_moe_model): def test_moe_model_detection(model_id, expected_is_moe_model):
model_config = ModelConfig(model_id) model_config = ModelConfig(model_id)
# Just check that is_moe_model field exists and is a boolean # Just check that is_moe field exists and is a boolean
assert model_config.is_model_moe() == expected_is_moe_model assert model_config.is_moe == expected_is_moe_model
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -226,7 +238,7 @@ def test_moe_model_detection(model_id, expected_is_moe_model): ...@@ -226,7 +238,7 @@ def test_moe_model_detection(model_id, expected_is_moe_model):
def test_is_quantized(model_id, quantized): def test_is_quantized(model_id, quantized):
model_config = ModelConfig(model_id) model_config = ModelConfig(model_id)
# Just check that quantized field exists and is a boolean # Just check that quantized field exists and is a boolean
assert model_config.is_quantized() == quantized assert model_config.is_quantized == quantized
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -556,100 +568,100 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files): ...@@ -556,100 +568,100 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
"jason9693/Qwen2.5-1.5B-apeach", "jason9693/Qwen2.5-1.5B-apeach",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and last pooling support chunked prefill.", "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
), ),
( (
"Qwen/Qwen3-Embedding-0.6B", "Qwen/Qwen3-Embedding-0.6B",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and last pooling support chunked prefill.", "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
), ),
( (
"Qwen/Qwen2.5-Math-PRM-7B", "Qwen/Qwen2.5-Math-PRM-7B",
"decoder", "decoder",
False, False,
"Pooling models with step pooling does not support chunked prefill.", "Pooling models with causal attn and LAST/STEP pooling do not support chunked prefill.", # noqa: E501
), ),
( (
"internlm/internlm2-1_8b-reward", "internlm/internlm2-1_8b-reward",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and all pooling support chunked prefill.", "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
), ),
( (
"BAAI/bge-base-en", "BAAI/bge-base-en",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support chunked prefill.", "Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
), ),
( (
"boltuix/NeuroBERT-NER", "boltuix/NeuroBERT-NER",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support chunked prefill.", "Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
), ),
( (
"papluca/xlm-roberta-base-language-detection", "papluca/xlm-roberta-base-language-detection",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support chunked prefill.", "Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
), ),
( (
"Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support chunked prefill.", "Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
), ),
( (
"intfloat/e5-small", "intfloat/e5-small",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support chunked prefill.", "Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
), ),
# multimodal models # multimodal models
( (
"openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and last pooling support chunked prefill.", "Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
), ),
( (
"google/siglip-base-patch16-224", "google/siglip-base-patch16-224",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support chunked prefill.", "Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
), ),
# generate models # generate models
( (
"Qwen/Qwen3-0.6B", "Qwen/Qwen3-0.6B",
"decoder", "decoder",
True, True,
"Generative models support chunked prefill.", "Generative models support chunked prefill.", # noqa: E501
), ),
( (
"Qwen/Qwen3-Next-80B-A3B-Instruct", "Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid", "hybrid",
True, True,
"Generative models support chunked prefill.", "Generative models support chunked prefill.", # noqa: E501
), ),
( (
"ibm-granite/granite-4.0-h-small", "ibm-granite/granite-4.0-h-small",
"hybrid", "hybrid",
True, True,
"Generative models support chunked prefill.", "Generative models support chunked prefill.", # noqa: E501
), ),
( (
"state-spaces/mamba-130m-hf", "state-spaces/mamba-130m-hf",
"attention_free", "attention_free",
True, True,
"Generative models support chunked prefill.", "Generative models support chunked prefill.", # noqa: E501
), ),
# encoder_decoder models # encoder_decoder models
( (
"openai/whisper-small", "openai/whisper-small",
"encoder_decoder", "encoder_decoder",
False, False,
"Encoder decoder models does not support chunked prefill.", "Encoder decoder models do not support chunked prefill.", # noqa: E501
), ),
], ],
) )
...@@ -675,100 +687,100 @@ def test_is_chunked_prefill_supported( ...@@ -675,100 +687,100 @@ def test_is_chunked_prefill_supported(
"jason9693/Qwen2.5-1.5B-apeach", "jason9693/Qwen2.5-1.5B-apeach",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and last pooling support prefix caching.", "Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
), ),
( (
"Qwen/Qwen3-Embedding-0.6B", "Qwen/Qwen3-Embedding-0.6B",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and last pooling support prefix caching.", "Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
), ),
( (
"Qwen/Qwen2.5-Math-PRM-7B", "Qwen/Qwen2.5-Math-PRM-7B",
"decoder", "decoder",
False, False,
"Pooling models with step pooling does not support prefix caching.", "Pooling models with causal attn and LAST/STEP pooling do not support prefix caching.", # noqa: E501
), ),
( (
"internlm/internlm2-1_8b-reward", "internlm/internlm2-1_8b-reward",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and all pooling support prefix caching.", "Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
), ),
( (
"BAAI/bge-base-en", "BAAI/bge-base-en",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support prefix caching.", "Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
), ),
( (
"boltuix/NeuroBERT-NER", "boltuix/NeuroBERT-NER",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support prefix caching.", "Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
), ),
( (
"papluca/xlm-roberta-base-language-detection", "papluca/xlm-roberta-base-language-detection",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support prefix caching.", "Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
), ),
( (
"Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support prefix caching.", "Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
), ),
( (
"intfloat/e5-small", "intfloat/e5-small",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support prefix caching.", "Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
), ),
# multimodal models # multimodal models
( (
"openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32",
"decoder", "decoder",
True, True,
"Pooling models with causal attn and last pooling support prefix caching.", "Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
), ),
( (
"google/siglip-base-patch16-224", "google/siglip-base-patch16-224",
"encoder_only", "encoder_only",
False, False,
"Pooling models with bidirectional attn does not support prefix caching.", "Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
), ),
# generate models # generate models
( (
"Qwen/Qwen3-0.6B", "Qwen/Qwen3-0.6B",
"decoder", "decoder",
True, True,
"Generative models support prefix caching.", "Generative models support prefix caching.", # noqa: E501
), ),
( (
"Qwen/Qwen3-Next-80B-A3B-Instruct", "Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid", "hybrid",
False, False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501 "Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
), ),
( (
"ibm-granite/granite-4.0-h-small", "ibm-granite/granite-4.0-h-small",
"hybrid", "hybrid",
False, False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501 "Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
), ),
( (
"state-spaces/mamba-130m-hf", "state-spaces/mamba-130m-hf",
"attention_free", "attention_free",
False, False,
"Attention free models does not support prefix caching since the feature is still experimental.", # noqa: E501 "Attention free models do not support prefix caching since the feature is still experimental.", # noqa: E501
), ),
# encoder_decoder models # encoder_decoder models
( (
"openai/whisper-small", "openai/whisper-small",
"encoder_decoder", "encoder_decoder",
False, False,
"Encoder decoder models does not support prefix caching.", "Encoder decoder models do not support prefix caching.", # noqa: E501
), ),
], ],
) )
...@@ -927,7 +939,7 @@ def test_vllm_config_callable_defaults(): ...@@ -927,7 +939,7 @@ def test_vllm_config_callable_defaults():
model_config=quantized_model, optimization_level=OptimizationLevel.O2 model_config=quantized_model, optimization_level=OptimizationLevel.O2
) )
enable_if_quantized = lambda cfg: ( enable_if_quantized = lambda cfg: (
cfg.model_config is not None and cfg.model_config.is_quantized() cfg.model_config is not None and cfg.model_config.is_quantized
) )
assert enable_if_quantized(config_quantized) is True assert enable_if_quantized(config_quantized) is True
assert enable_if_quantized(config_no_model) is False assert enable_if_quantized(config_no_model) is False
...@@ -938,7 +950,7 @@ def test_vllm_config_callable_defaults(): ...@@ -938,7 +950,7 @@ def test_vllm_config_callable_defaults():
model_config=moe_model, optimization_level=OptimizationLevel.O2 model_config=moe_model, optimization_level=OptimizationLevel.O2
) )
enable_if_sequential = lambda cfg: ( enable_if_sequential = lambda cfg: (
cfg.model_config is not None and not cfg.model_config.is_model_moe() cfg.model_config is not None and not cfg.model_config.is_moe
) )
assert enable_if_sequential(config_moe) is False assert enable_if_sequential(config_moe) is False
assert enable_if_sequential(config_quantized) is True assert enable_if_sequential(config_quantized) is True
...@@ -1052,3 +1064,46 @@ def test_scheduler_config_init(): ...@@ -1052,3 +1064,46 @@ def test_scheduler_config_init():
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
# InitVar does not become an attribute # InitVar does not become an attribute
print(SchedulerConfig.default_factory().max_model_len) print(SchedulerConfig.default_factory().max_model_len)
@pytest.mark.parametrize(
(
"model_id",
"data_parallel_size",
"external_lb",
"expected_needs_coordinator",
),
[
# Non-MoE model with DP=1 should not need coordinator
("facebook/opt-125m", 1, False, False),
# Non-MoE model with DP>1 internal LB should need coordinator
("facebook/opt-125m", 2, False, True),
# Non-MoE model with DP>1 external LB should not need coordinator
("facebook/opt-125m", 2, True, False),
# MoE model with DP=1 should not need coordinator
("mistralai/Mixtral-8x7B-Instruct-v0.1", 1, False, False),
# MoE model with DP>1 internal LB should need both coordinator
# and wave coordination
("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, False, True),
# MoE model with DP>1 external LB needs coordinator for wave coordination
# (wave coordination runs in coordinator process)
("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, True, True),
],
)
def test_needs_dp_coordination(
model_id,
data_parallel_size,
external_lb,
expected_needs_coordinator,
):
"""Test that DP coordinator and wave coordination are configured correctly."""
from vllm.config import ParallelConfig
model_config = ModelConfig(model_id)
parallel_config = ParallelConfig(
data_parallel_size=data_parallel_size,
data_parallel_external_lb=external_lb,
)
vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)
assert vllm_config.needs_dp_coordinator == expected_needs_coordinator
...@@ -18,7 +18,7 @@ EMBEDDING_MODELS = [ ...@@ -18,7 +18,7 @@ EMBEDDING_MODELS = [
] ]
classify_parameters = ["use_activation"] classify_parameters = ["use_activation"]
embed_parameters = ["dimensions", "normalize"] embed_parameters = ["dimensions", "use_activation"]
step_pooling_parameters = ["step_tag_id", "returned_token_ids"] step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
...@@ -40,19 +40,19 @@ def test_task(): ...@@ -40,19 +40,19 @@ def test_task():
def test_embed(): def test_embed():
task = "embed" task = "embed"
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS")) model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(normalize=None) pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=True) pooling_params = PoolingParams(use_activation=True)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=False) pooling_params = PoolingParams(use_activation=False)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = classify_parameters + step_pooling_parameters invalid_parameters = classify_parameters + step_pooling_parameters
for p in invalid_parameters: for p in set(invalid_parameters) - set(embed_parameters):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
...@@ -86,7 +86,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo): ...@@ -86,7 +86,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
@pytest.mark.parametrize("task", ["score", "classify"]) @pytest.mark.parametrize("task", ["score", "classify"])
def test_classify(task): def test_classify(task):
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS")) model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(use_activation=None) pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
...@@ -98,7 +98,7 @@ def test_classify(task): ...@@ -98,7 +98,7 @@ def test_classify(task):
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = embed_parameters + step_pooling_parameters invalid_parameters = embed_parameters + step_pooling_parameters
for p in invalid_parameters: for p in set(invalid_parameters) - set(classify_parameters):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
...@@ -108,23 +108,23 @@ def test_classify(task): ...@@ -108,23 +108,23 @@ def test_classify(task):
def test_token_embed(pooling_type: str): def test_token_embed(pooling_type: str):
task = "token_embed" task = "token_embed"
model_config = MockModelConfig( model_config = MockModelConfig(
pooler_config=PoolerConfig(pooling_type=pooling_type) pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
) )
pooling_params = PoolingParams(normalize=None) pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=True) pooling_params = PoolingParams(use_activation=True)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=False) pooling_params = PoolingParams(use_activation=False)
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = classify_parameters invalid_parameters = classify_parameters
if pooling_type != "STEP": if pooling_type != "STEP":
invalid_parameters = classify_parameters + step_pooling_parameters invalid_parameters = classify_parameters + step_pooling_parameters
for p in invalid_parameters: for p in set(invalid_parameters) - set(embed_parameters):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
...@@ -134,7 +134,7 @@ def test_token_embed(pooling_type: str): ...@@ -134,7 +134,7 @@ def test_token_embed(pooling_type: str):
def test_token_classify(pooling_type: str): def test_token_classify(pooling_type: str):
task = "token_classify" task = "token_classify"
model_config = MockModelConfig( model_config = MockModelConfig(
pooler_config=PoolerConfig(pooling_type=pooling_type) pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
) )
pooling_params = PoolingParams(use_activation=None) pooling_params = PoolingParams(use_activation=None)
...@@ -150,7 +150,7 @@ def test_token_classify(pooling_type: str): ...@@ -150,7 +150,7 @@ def test_token_classify(pooling_type: str):
if pooling_type != "STEP": if pooling_type != "STEP":
invalid_parameters = embed_parameters + step_pooling_parameters invalid_parameters = embed_parameters + step_pooling_parameters
for p in invalid_parameters: for p in set(invalid_parameters) - set(classify_parameters):
with pytest.raises(ValueError): with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True}) pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config) pooling_params.verify(task=task, model_config=model_config)
...@@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device): ...@@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s envs.environment_variables[env_name] = lambda s=strategy: s
# Test the select_experts method # Test the select_experts method
topk_weights, topk_ids, _ = fused_moe.select_experts( topk_weights, topk_ids = fused_moe.router.select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
) )
......
...@@ -10,6 +10,7 @@ from transformers import ( ...@@ -10,6 +10,7 @@ from transformers import (
) )
from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
...@@ -37,6 +38,10 @@ def test_tokenizer_like_protocol(): ...@@ -37,6 +38,10 @@ def test_tokenizer_like_protocol():
assert isinstance(tokenizer, MistralTokenizer) assert isinstance(tokenizer, MistralTokenizer)
_assert_tokenizer_like(tokenizer) _assert_tokenizer_like(tokenizer)
tokenizer = get_tokenizer("xai-org/grok-2", tokenizer_mode="grok2")
assert isinstance(tokenizer, Grok2Tokenizer)
_assert_tokenizer_like(tokenizer)
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"]) @pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
def test_tokenizer_revision(tokenizer_name: str): def test_tokenizer_revision(tokenizer_name: str):
......
...@@ -40,7 +40,8 @@ TOKENIZERS = [ ...@@ -40,7 +40,8 @@ TOKENIZERS = [
os.path.join(models_path_prefix, "EleutherAI/gpt-j-6b"), os.path.join(models_path_prefix, "EleutherAI/gpt-j-6b"),
os.path.join(models_path_prefix, "EleutherAI/pythia-70m"), os.path.join(models_path_prefix, "EleutherAI/pythia-70m"),
os.path.join(models_path_prefix, "bigscience/bloom-560m"), os.path.join(models_path_prefix, "bigscience/bloom-560m"),
os.path.join(models_path_prefix, "mosaicml/mpt-7b"), # FIXME: mosaicml/mpt-7b has been deleted
# "mosaicml/mpt-7b",
os.path.join(models_path_prefix, "tiiuae/falcon-7b"), os.path.join(models_path_prefix, "tiiuae/falcon-7b"),
os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"), os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"),
os.path.join(models_path_prefix, "codellama/CodeLlama-7b-hf"), os.path.join(models_path_prefix, "codellama/CodeLlama-7b-hf"),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.tool_parsers.functiongemma_tool_parser import FunctionGemmaToolParser
@pytest.fixture
def mock_tokenizer():
tokenizer = MagicMock()
tokenizer.encode.return_value = [1, 2, 3]
tokenizer.get_vocab.return_value = {}
return tokenizer
@pytest.fixture
def parser(mock_tokenizer):
return FunctionGemmaToolParser(mock_tokenizer)
@pytest.fixture
def mock_request():
request = MagicMock(spec=ChatCompletionRequest)
request.tools = []
request.tool_choice = "auto"
return request
class TestExtractToolCalls:
def test_no_tool_calls(self, parser, mock_request):
model_output = "Hello, how can I help you today?"
result = parser.extract_tool_calls(model_output, mock_request)
assert result.tools_called is False
assert result.tool_calls == []
assert result.content == model_output
def test_single_tool_call(self, parser, mock_request):
model_output = (
"<start_function_call>call:get_weather{location:<escape>London<escape>}"
"<end_function_call>"
)
result = parser.extract_tool_calls(model_output, mock_request)
assert result.tools_called is True
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "get_weather"
assert '"location": "London"' in result.tool_calls[0].function.arguments
def test_multiple_arguments(self, parser, mock_request):
model_output = (
"<start_function_call>call:get_weather{"
"location:<escape>San Francisco<escape>,"
"unit:<escape>celsius<escape>}"
"<end_function_call>"
)
result = parser.extract_tool_calls(model_output, mock_request)
assert result.tools_called is True
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "get_weather"
args = result.tool_calls[0].function.arguments
assert "San Francisco" in args
assert "celsius" in args
def test_text_before_tool_call(self, parser, mock_request):
model_output = (
"Let me check the weather for you. "
"<start_function_call>call:get_weather{location:<escape>Paris<escape>}"
"<end_function_call>"
)
result = parser.extract_tool_calls(model_output, mock_request)
assert result.tools_called is True
assert result.content == "Let me check the weather for you."
def test_multiple_tool_calls(self, parser, mock_request):
model_output = (
"<start_function_call>call:get_weather{location:<escape>London<escape>}"
"<end_function_call>"
"<start_function_call>call:get_time{timezone:<escape>UTC<escape>}"
"<end_function_call>"
)
result = parser.extract_tool_calls(model_output, mock_request)
assert result.tools_called is True
assert len(result.tool_calls) == 2
assert result.tool_calls[0].function.name == "get_weather"
assert result.tool_calls[1].function.name == "get_time"
class TestParseArguments:
def test_empty_arguments(self, parser):
result = parser._parse_arguments("")
assert result == {}
def test_single_string_argument(self, parser):
result = parser._parse_arguments("city:<escape>Tokyo<escape>")
assert result == {"city": "Tokyo"}
def test_multiple_arguments(self, parser):
args_str = "city:<escape>Tokyo<escape>,country:<escape>Japan<escape>"
result = parser._parse_arguments(args_str)
assert result == {"city": "Tokyo", "country": "Japan"}
def test_numeric_argument(self, parser):
result = parser._parse_arguments("count:<escape>42<escape>")
assert result == {"count": 42}
def test_boolean_argument(self, parser):
result = parser._parse_arguments("enabled:<escape>true<escape>")
assert result == {"enabled": True}
def test_argument_with_spaces(self, parser):
result = parser._parse_arguments("message:<escape>Hello World<escape>")
assert result == {"message": "Hello World"}
class TestAdjustRequest:
def test_skip_special_tokens_disabled(self, parser, mock_request):
mock_request.tools = [{"type": "function", "function": {"name": "test"}}]
mock_request.tool_choice = "auto"
mock_request.skip_special_tokens = True
result = parser.adjust_request(mock_request)
assert result.skip_special_tokens is False
def test_skip_special_tokens_when_tool_choice_none(self, parser, mock_request):
mock_request.tools = [{"type": "function", "function": {"name": "test"}}]
mock_request.tool_choice = "none"
mock_request.skip_special_tokens = True
result = parser.adjust_request(mock_request)
assert result.skip_special_tokens is True
class TestBufferDeltaText:
def test_regular_text_not_buffered(self, parser):
result = parser._buffer_delta_text("hello")
assert result == "hello"
assert parser.buffered_delta_text == ""
def test_complete_tag_flushed(self, parser):
parser.buffered_delta_text = "<start_function_"
result = parser._buffer_delta_text("call>")
assert "<start_function_call>" in result
if __name__ == "__main__":
pytest.main([__file__, "-v"])
...@@ -44,6 +44,33 @@ def assert_tool_calls( ...@@ -44,6 +44,33 @@ def assert_tool_calls(
) )
def run_streaming_sequence(parser, deltas):
"""Helper to simulate a streaming sequence and return results."""
previous_text = ""
previous_token_ids: list[int] = []
results = []
for delta_text, delta_token_ids in deltas:
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + delta_token_ids
result = parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=None,
)
results.append(result)
previous_text = current_text
previous_token_ids = current_token_ids
return results
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser): def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
model_output = "This is a test" model_output = "This is a test"
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
...@@ -346,61 +373,32 @@ def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser): ...@@ -346,61 +373,32 @@ def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser):
tool_call_begin_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>") tool_call_begin_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
# Simulate streaming sequence: # Simulate streaming sequence:
deltas = [
("I'll help you with that. ", [1, 2, 3]),
("<|tool_calls_section_begin|>", [section_begin_token_id]),
(" spurious text ", [4, 5]),
("<|tool_call_begin|>", [tool_call_begin_token_id]),
]
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
# Delta 1: "I'll help you with that. " # Delta 1: "I'll help you with that. "
result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( assert results[0] is not None
previous_text="", assert results[0].content == "I'll help you with that. "
current_text="I'll help you with that. ",
delta_text="I'll help you with that. ",
previous_token_ids=[],
current_token_ids=[1, 2, 3], # Regular tokens
delta_token_ids=[1, 2, 3],
request=None,
)
assert result1 is not None
assert result1.content == "I'll help you with that. "
# Delta 2: "<|tool_calls_section_begin|>" # Delta 2: "<|tool_calls_section_begin|>"
prev_ids = [1, 2, 3]
curr_ids = prev_ids + [section_begin_token_id]
result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="I'll help you with that. ",
current_text="I'll help you with that. <|tool_calls_section_begin|>",
delta_text="<|tool_calls_section_begin|>",
previous_token_ids=prev_ids,
current_token_ids=curr_ids,
delta_token_ids=[section_begin_token_id],
request=None,
)
# Section marker should be stripped and suppressed # Section marker should be stripped and suppressed
assert result2 is None or (result2.content is None or result2.content == "") assert results[1] is None or (
results[1].content is None or results[1].content == ""
)
# Delta 3: " spurious text or tokens " (THE LEAK SCENARIO) # Delta 3: " spurious text or tokens " (THE LEAK SCENARIO)
prev_ids = curr_ids
curr_ids = curr_ids + [4, 5]
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="I'll help you with that. <|tool_calls_section_begin|>",
current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
delta_text=" spurious text ",
previous_token_ids=prev_ids,
current_token_ids=curr_ids,
delta_token_ids=[4, 5],
request=None,
)
# CRITICAL: This text should be suppressed, NOT returned as reasoning_delta # CRITICAL: This text should be suppressed, NOT returned as reasoning_delta
assert result3 is None or (result3.content is None or result3.content == "") assert results[2] is None or (
results[2].content is None or results[2].content == ""
)
# Delta 4: "<|tool_call_begin|>..." # Delta 4: "<|tool_call_begin|>..."
prev_ids = curr_ids
curr_ids = curr_ids + [tool_call_begin_token_id]
_result4 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text <|tool_call_begin|>",
delta_text="<|tool_call_begin|>",
previous_token_ids=prev_ids,
current_token_ids=curr_ids,
delta_token_ids=[tool_call_begin_token_id],
request=None,
)
# Now we're in tool call mode, result depends on internal state # Now we're in tool call mode, result depends on internal state
# The key is that the spurious text from Delta 3 was not leaked # The key is that the spurious text from Delta 3 was not leaked
...@@ -416,31 +414,15 @@ def test_split_markers_across_deltas(kimi_k2_tool_parser): ...@@ -416,31 +414,15 @@ def test_split_markers_across_deltas(kimi_k2_tool_parser):
"<|tool_calls_section_begin|>" "<|tool_calls_section_begin|>"
) )
# Delta 1: "...reasoning<|tool_calls_sec" # Delta 1: partial token, Delta 2: complete marker
_result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( deltas = [
previous_text="Some reasoning", ("<|tool_calls_sec", [3]),
current_text="Some reasoning<|tool_calls_sec", ("tion_begin|> ", [section_begin_token_id, 4]),
delta_text="<|tool_calls_sec", ]
previous_token_ids=[1, 2],
current_token_ids=[1, 2, 3], # Partial token _results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
delta_token_ids=[3],
request=None,
)
# Partial token not recognized yet, might be buffered
# Should return as content or None (depends on implementation)
# Delta 2: "tion_begin|> " (completes the marker)
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="Some reasoning<|tool_calls_sec",
current_text="Some reasoning<|tool_calls_section_begin|> ",
delta_text="tion_begin|> ",
previous_token_ids=[1, 2, 3],
current_token_ids=[1, 2, section_begin_token_id, 4],
delta_token_ids=[section_begin_token_id, 4],
request=None,
)
# Now the complete marker should be detected via buffer # Now the complete marker should be detected via buffer
# The parser should enter tool section mode
assert kimi_k2_tool_parser.in_tool_section is True assert kimi_k2_tool_parser.in_tool_section is True
...@@ -475,42 +457,17 @@ def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser): ...@@ -475,42 +457,17 @@ def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser):
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>") section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
# Enter tool section deltas = [
_result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( ("<|tool_calls_section_begin|>", [section_begin_id]),
previous_text="", ("<|tool_calls_section_end|>", [section_end_id]),
current_text="<|tool_calls_section_begin|>", (" More reasoning", [10, 11]),
delta_text="<|tool_calls_section_begin|>", ]
previous_token_ids=[],
current_token_ids=[section_begin_id],
delta_token_ids=[section_begin_id],
request=None,
)
assert kimi_k2_tool_parser.in_tool_section is True
# Exit tool section results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="<|tool_calls_section_begin|>",
current_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
delta_text="<|tool_calls_section_end|>",
previous_token_ids=[section_begin_id],
current_token_ids=[section_begin_id, section_end_id],
delta_token_ids=[section_end_id],
request=None,
)
assert kimi_k2_tool_parser.in_tool_section is False
# Subsequent reasoning text should be returned normally assert kimi_k2_tool_parser.in_tool_section is False
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming( assert results[2] is not None
previous_text="<|tool_calls_section_begin|><|tool_calls_section_end|>", assert results[2].content == " More reasoning"
current_text="<|tool_calls_section_begin|><|tool_calls_section_end|> More reasoning",
delta_text=" More reasoning",
previous_token_ids=[section_begin_id, section_end_id],
current_token_ids=[section_begin_id, section_end_id, 10, 11],
delta_token_ids=[10, 11],
request=None,
)
assert result3 is not None
assert result3.content == " More reasoning"
def test_empty_tool_section(kimi_k2_tool_parser): def test_empty_tool_section(kimi_k2_tool_parser):
...@@ -819,106 +776,150 @@ def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser): ...@@ -819,106 +776,150 @@ def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser):
tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>") tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
# Simulate a streaming sequence for a SHORT tool call (all in one chunk): # Simulate a streaming sequence for a SHORT tool call (all in one chunk):
# 1. Reasoning text
result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text="Let me help. ",
delta_text="Let me help. ",
previous_token_ids=[],
current_token_ids=[1, 2],
delta_token_ids=[1, 2],
request=None,
)
assert result1 is not None
assert result1.content == "Let me help. "
# 2. Section begin
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="Let me help. ",
current_text="Let me help. <|tool_calls_section_begin|>",
delta_text="<|tool_calls_section_begin|>",
previous_token_ids=[1, 2],
current_token_ids=[1, 2, section_begin_id],
delta_token_ids=[section_begin_id],
request=None,
)
assert kimi_k2_tool_parser.in_tool_section is True
# 3. Tool call begin + full content + tool_end + section_end ALL IN ONE CHUNK
# This is the critical scenario for short tool calls
combined = ( combined = (
'<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} ' '<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} '
"<|tool_call_end|><|tool_calls_section_end|>" "<|tool_call_end|><|tool_calls_section_end|>"
) )
# Build up the previous text gradually to simulate realistic streaming deltas = [
prev_text = "Let me help. <|tool_calls_section_begin|>" ("Let me help. ", [1, 2]),
curr_text = prev_text + combined ("<|tool_calls_section_begin|>", [section_begin_id]),
(combined, [tool_begin_id, 10, 11, 12, tool_end_id, section_end_id]),
(" Done", [20]),
]
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming( results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
previous_text=prev_text,
current_text=curr_text,
delta_text=combined,
previous_token_ids=[1, 2, section_begin_id],
current_token_ids=[
1,
2,
section_begin_id,
tool_begin_id,
10,
11,
12,
tool_end_id,
section_end_id,
],
delta_token_ids=[tool_begin_id, 10, 11, 12, tool_end_id, section_end_id],
request=None,
)
# CRITICAL: Parser should have exited section AFTER processing tool # CRITICAL: Parser should have exited section AFTER processing tool
assert kimi_k2_tool_parser.in_tool_section is False assert kimi_k2_tool_parser.in_tool_section is False
# Tool call should have been emitted (not dropped) # Tool call should have been emitted (not dropped)
# The result might be the tool name or None depending on state, but if results[2] is not None and results[2].content is not None:
# importantly, it shouldn't be returning the literal tokens as content
if result3 is not None and result3.content is not None:
# Verify no special tokens leaked into content # Verify no special tokens leaked into content
assert "<|tool_call_end|>" not in result3.content assert "<|tool_call_end|>" not in results[2].content
assert "<|tool_calls_section_end|>" not in result3.content assert "<|tool_calls_section_end|>" not in results[2].content
# 4. Verify subsequent content streams normally # Content after tool section should stream normally
result4 = kimi_k2_tool_parser.extract_tool_calls_streaming( assert results[3] is not None
previous_text=curr_text, assert results[3].content == " Done"
current_text=curr_text + " Done",
delta_text=" Done",
previous_token_ids=[ def test_streaming_tool_call_markers_not_leaked(kimi_k2_tool_parser):
1, """
2, CRITICAL TEST: Verify that tool call markers (<|tool_call_begin|>,
section_begin_id, <|tool_call_end|>, <|tool_call_argument_begin|>) are NOT leaked
tool_begin_id, into the content field during streaming.
10,
11, This reproduces the AWS Bedrock bug where tool call markers appeared
12, in the 'text' field of responses.
tool_end_id, """
section_end_id, kimi_k2_tool_parser.reset_streaming_state()
],
current_token_ids=[ section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
1, section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
2, tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
section_begin_id, tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
tool_begin_id,
10, # List of markers that should NEVER appear in content
11, forbidden_markers = [
12, "<|tool_call_begin|>",
tool_end_id, "<|tool_call_end|>",
section_end_id, "<|tool_call_argument_begin|>",
20, "<|tool_calls_section_begin|>",
], "<|tool_calls_section_end|>",
delta_token_ids=[20], ]
request=None,
all_content = []
# Steps: reasoning, section begin, tool call, section end, more reasoning
tool_chunk = (
"<|tool_call_begin|> functions.get_weather:0 "
'<|tool_call_argument_begin|> {"city": "Tokyo"} <|tool_call_end|>'
) )
deltas = [
("I'll check the weather. ", [1, 2, 3]),
("<|tool_calls_section_begin|>", [section_begin_id]),
(tool_chunk, [tool_begin_id, 10, 11, tool_end_id]),
("<|tool_calls_section_end|>", [section_end_id]),
(" Here's the result.", [20, 21]),
]
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
for res in results:
if res and res.content:
all_content.append(res.content)
# CRITICAL ASSERTIONS: No forbidden markers in any content
full_content = "".join(all_content)
for marker in forbidden_markers:
assert marker not in full_content, (
f"MARKER LEAK DETECTED: '{marker}' found in content. "
f"Full content: {repr(full_content)}"
)
# Content after tool section should stream normally # Also check that tool call content (function name, arguments) is not leaked
assert result4 is not None assert "get_weather" not in full_content, (
assert result4.content == " Done" f"TOOL CALL CONTENT LEAKED: 'get_weather' found in content. "
f"Full content: {repr(full_content)}"
)
assert "Tokyo" not in full_content, (
f"TOOL CALL CONTENT LEAKED: 'Tokyo' found in content. "
f"Full content: {repr(full_content)}"
)
# Verify that legitimate content was preserved
assert "I'll check the weather." in full_content or len(all_content) > 0
def test_streaming_multiple_tool_calls_not_leaked(kimi_k2_tool_parser):
"""
Test that MULTIPLE tool calls in streaming mode do not leak into content.
This reproduces the AWS Bedrock scenario: "Compare weather in Tokyo and NYC".
"""
kimi_k2_tool_parser.reset_streaming_state()
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
all_content = []
tool1 = '<|tool_call_begin|> get_weather:0 <|tool_call_argument_begin|> {"city": "Tokyo"} <|tool_call_end|>'
tool2 = ' <|tool_call_begin|> get_weather:1 <|tool_call_argument_begin|> {"city": "New York"} <|tool_call_end|>'
deltas = [
("I'll compare the weather. ", [1, 2, 3]),
("<|tool_calls_section_begin|>", [section_begin_id]),
(tool1, [tool_begin_id, 10, tool_end_id]),
(tool2, [tool_begin_id, 20, tool_end_id]),
("<|tool_calls_section_end|>", [section_end_id]),
(" Here's the comparison.", [30]),
]
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
for res in results:
if res and res.content:
all_content.append(res.content)
# Assertions
full_content = "".join(all_content)
# Check no markers leaked
forbidden = ["<|tool_call", "<|tool_calls_section"]
for marker in forbidden:
assert marker not in full_content, (
f"MARKER LEAKED: {marker} in {repr(full_content)}"
)
# Check no tool call content leaked (both tools)
assert "get_weather" not in full_content, f"TOOL NAME LEAKED: {repr(full_content)}"
assert "Tokyo" not in full_content, f"TOOL ARG LEAKED (Tokyo): {repr(full_content)}"
assert "New York" not in full_content, (
f"TOOL ARG LEAKED (NYC): {repr(full_content)}"
)
# Legitimate content preserved
assert "compare" in full_content.lower() or len(all_content) > 0
...@@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer( ...@@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer(
"single_tool_add", "single_tool_add",
"single_tool_weather", "single_tool_weather",
"multiple_tool_calls", "multiple_tool_calls",
"complex",
"wrong_json",
], ],
argnames=["model_output", "expected_tool_calls", "expected_content"], argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[ argvalues=[
...@@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer( ...@@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer(
], ],
None, None,
), ),
(
# Complex
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
)[:-2],
)
)
],
"hi{hi",
),
(
# Wrong json
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
),
)
)
],
"hi{hi",
),
], ],
) )
def test_extract_tool_calls( def test_extract_tool_calls(
...@@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming( ...@@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming(
), ),
( (
# Complex # Complex
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501 """hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[ [
ToolCall( ToolCall(
function=FunctionCall( function=FunctionCall(
...@@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming( ...@@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming(
) )
) )
], ],
"", "hi{hi",
), ),
], ],
) )
......
...@@ -151,3 +151,45 @@ async def test_chat_completion_with_tools( ...@@ -151,3 +151,45 @@ async def test_chat_completion_with_tools(
assert chunk.choices[0].finish_reason != "tool_calls" assert chunk.choices[0].finish_reason != "tool_calls"
assert len(chunks) assert len(chunks)
assert "".join(chunks) == output_text assert "".join(chunks) == output_text
# Regression test for https://github.com/vllm-project/vllm/issues/32006
# Engine crash when combining response_format: json_object with
# tool_choice: required
@pytest.mark.asyncio
@pytest.mark.timeout(120)
async def test_response_format_with_tool_choice_required(
client: openai.AsyncOpenAI, server_config: ServerConfig
):
"""
Test that combining response_format: json_object with tool_choice: required
doesn't crash the engine.
Before the fix, this would cause a validation error:
"You can only use one kind of structured outputs constraint but multiple
are specified" because both json_object and json (from tool schema) would
be set in StructuredOutputsParams.
"""
models = await client.models.list()
model_name: str = models.data[0].id
# This combination previously crashed the engine
chat_completion = await client.chat.completions.create(
messages=ensure_system_prompt(
[{"role": "user", "content": "What is the weather in Dallas, Texas?"}],
server_config,
),
temperature=0,
max_completion_tokens=150,
model=model_name,
tools=[WEATHER_TOOL],
tool_choice="required",
response_format={"type": "json_object"},
)
# The fix clears response_format when tool_choice forces tool calling,
# so the request should complete successfully with tool calls
choice = chat_completion.choices[0]
assert choice.finish_reason == "tool_calls"
assert choice.message.tool_calls is not None
assert len(choice.message.tool_calls) > 0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from vllm.tool_parsers.minimax_m2_tool_parser import (
MinimaxM2ToolParser,
)
pytestmark = pytest.mark.cpu_test
class FakeTokenizer:
"""Minimal fake tokenizer that exposes the attributes used by the
parser: a truthy model_tokenizer marker and a vocab mapping for the
special tokens.
"""
def __init__(self):
self.model_tokenizer = True
# The parser will look up start/end tokens by their literal strings
self.vocab = {
"<minimax:tool_call>": 1,
"</minimax:tool_call>": 2,
}
def get_vocab(self):
return self.vocab
@pytest.fixture
def minimax_m2_tool_parser():
return MinimaxM2ToolParser(FakeTokenizer())
def test_extract_tool_calls_streaming_incremental(minimax_m2_tool_parser):
parser = minimax_m2_tool_parser
parser._reset_streaming_state()
chunks = [
"<minimax:tool_call>",
'<invoke name="get_weather">',
'<parameter name="city">',
"Seattle</parameter>",
"</invoke></minimax:tool_call>",
]
previous = ""
for chunk in chunks:
current = previous + chunk
delta = chunk
parser.extract_tool_calls_streaming(
previous_text=previous,
current_text=current,
delta_text=delta,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
previous = current
assert len(parser.prev_tool_call_arr) == 1
entry = parser.prev_tool_call_arr[0]
assert entry["name"] == "get_weather"
args = entry["arguments"]
assert args["city"] == "Seattle"
def test_streaming_minimax_m2_multiple_invokes(minimax_m2_tool_parser):
parser = minimax_m2_tool_parser
parser._reset_streaming_state()
chunks = [
"<minimax:tool_call>",
'<invoke name="search_web">',
'<parameter name="query_tag">',
'["technology", "events"]</parameter>',
'<parameter name="query_list">',
'["OpenAI", "latest", "release"]</parameter>',
"</invoke>",
'<invoke name="search_web">',
'<parameter name="query_tag">',
'["technology", "events"]</parameter>',
'<parameter name="query_list">',
'["Gemini", "latest", "release"]</parameter>',
"</invoke>",
"</minimax:tool_call>",
]
previous = ""
for chunk in chunks:
current = previous + chunk
delta = chunk
parser.extract_tool_calls_streaming(
previous_text=previous,
current_text=current,
delta_text=delta,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
previous = current
assert len(parser.prev_tool_call_arr) == 2
for entry, expect_model in zip(parser.prev_tool_call_arr, ["OpenAI", "Gemini"]):
assert entry["name"] == "search_web"
args = json.dumps(entry["arguments"])
assert "technology" in args and "events" in args
assert expect_model in args
# check streamed_args_for_tool for serving_chat.py
for index in range(2):
expected_call = parser.prev_tool_call_arr[index].get("arguments", {})
expected_call = json.dumps(expected_call)
actual_call = parser.streamed_args_for_tool[index]
assert expected_call == actual_call
...@@ -311,6 +311,7 @@ def test_streaming_output_valid(output, empty_params, delta_len): ...@@ -311,6 +311,7 @@ def test_streaming_output_valid(output, empty_params, delta_len):
previous_text = current_text previous_text = current_text
assert len(messages) > 0 assert len(messages) > 0
combined_messages = "[" combined_messages = "["
for message in messages: for message in messages:
if message.tool_calls[0].function.name: if message.tool_calls[0].function.name:
...@@ -328,3 +329,35 @@ def test_streaming_output_valid(output, empty_params, delta_len): ...@@ -328,3 +329,35 @@ def test_streaming_output_valid(output, empty_params, delta_len):
combined_messages += "}]" combined_messages += "}]"
assert json.loads(combined_messages) == output assert json.loads(combined_messages) == output
assert json.dumps(json.loads(combined_messages)) == output_json assert json.dumps(json.loads(combined_messages)) == output_json
def test_streaming_output_valid_with_trailing_extra_data():
self = MagicMock()
output = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
output_json = json.dumps(output) + "\nDONE"
previous_text = ""
function_name_returned = False
messages = []
delta_len = 3
for i in range(0, len(output_json), delta_len):
delta_text = output_json[i : i + delta_len]
current_text = previous_text + delta_text
delta_message, function_name_returned = (
OpenAIServingChat.extract_tool_call_required_streaming(
self,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
)
)
if delta_message:
messages.append(delta_message)
previous_text = current_text
assert len(messages) > 0
...@@ -112,6 +112,7 @@ class RemoteOpenAIServer: ...@@ -112,6 +112,7 @@ class RemoteOpenAIServer:
env.update(env_dict) env.update(env_dict)
serve_cmd = ["vllm", "serve", model, *vllm_serve_args] serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}") print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
print(f"Environment variables: {env}")
self.proc: subprocess.Popen = subprocess.Popen( self.proc: subprocess.Popen = subprocess.Popen(
serve_cmd, serve_cmd,
env=env, env=env,
...@@ -726,13 +727,34 @@ def init_test_distributed_environment( ...@@ -726,13 +727,34 @@ def init_test_distributed_environment(
distributed_init_port: str, distributed_init_port: str,
local_rank: int = -1, local_rank: int = -1,
) -> None: ) -> None:
distributed_init_method = f"tcp://localhost:{distributed_init_port}" # Note: This function is often called from Ray worker processes, so we
init_distributed_environment( # can't rely on pytest fixtures to set the config. We check if the config
world_size=pp_size * tp_size, # is already set and only create a default one if needed.
rank=rank, from vllm.config import (
distributed_init_method=distributed_init_method, VllmConfig,
local_rank=local_rank, get_current_vllm_config_or_none,
set_current_vllm_config,
) )
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
if get_current_vllm_config_or_none() is not None:
# Config already set, use it directly
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
)
else:
# No config set, create a default one for the test
with set_current_vllm_config(VllmConfig()):
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
)
ensure_model_parallel_initialized(tp_size, pp_size) ensure_model_parallel_initialized(tp_size, pp_size)
......
...@@ -99,30 +99,18 @@ def _test_stream_thread(main_expected_stream: torch.cuda.Stream): ...@@ -99,30 +99,18 @@ def _test_stream_thread(main_expected_stream: torch.cuda.Stream):
def test_current_stream_multithread(): def test_current_stream_multithread():
from vllm.platforms import current_platform
if not torch.cuda.is_available(): if not torch.cuda.is_available():
pytest.skip("CUDA not available") pytest.skip("CUDA not available")
if current_platform.is_rocm(): main_dedicated_stream = current_stream()
main_dedicated_stream = current_stream()
assert main_dedicated_stream.cuda_stream != 0, (
"ROCm should create a dedicated stream, not use default stream (0x0)"
)
main_stream_again = current_stream()
assert main_stream_again == main_dedicated_stream, (
"Multiple calls to current_stream should return the same dedicated stream"
)
_test_stream_thread(main_dedicated_stream) assert main_dedicated_stream.cuda_stream != 0, (
else: "ROCm/CUDA should create a dedicated stream, not use default stream (0x0)"
main_default_stream = torch.cuda.default_stream() )
main_initial_stream = current_stream()
assert main_initial_stream == main_default_stream, ( main_stream_again = current_stream()
"First call to current_stream should return default stream on CUDA" assert main_stream_again == main_dedicated_stream, (
) "Multiple calls to current_stream should return the same dedicated stream"
)
_test_stream_thread(main_default_stream) _test_stream_thread(main_dedicated_stream)
...@@ -15,13 +15,17 @@ from tests.v1.attention.utils import ( ...@@ -15,13 +15,17 @@ from tests.v1.attention.utils import (
create_vllm_config, create_vllm_config,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
is_torch_equal_or_newer,
set_random_seed,
)
from vllm.v1.attention.backend import AttentionType, CommonAttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
set_kv_cache_layout, set_kv_cache_layout,
) )
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
...@@ -79,6 +83,13 @@ BATCH_SPECS = { ...@@ -79,6 +83,13 @@ BATCH_SPECS = {
), ),
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
# encoder-only
"small_encoder_prefill": BatchSpec(
seq_lens=[32, 64, 128, 256], query_lens=[32, 64, 128, 256]
),
"medium_encoder_prefill": BatchSpec(
seq_lens=[256, 512, 1024, 2048], query_lens=[256, 512, 1024, 2048]
),
} }
...@@ -114,17 +125,17 @@ def create_and_prepopulate_kv_cache( ...@@ -114,17 +125,17 @@ def create_and_prepopulate_kv_cache(
Tuple of (kv_cache, updated_block_table) Tuple of (kv_cache, updated_block_table)
""" """
batch_size = len(k_contexts) batch_size = len(k_contexts)
seq_lens = common_attn_metadata.seq_lens_cpu seq_lens = common_attn_metadata.seq_lens.cpu()
query_lens = ( query_lens = (
common_attn_metadata.query_start_loc_cpu[1:] common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1] - common_attn_metadata.query_start_loc_cpu[:-1]
) )
context_lens = common_attn_metadata.num_computed_tokens_cpu context_lens = seq_lens - query_lens
block_table = common_attn_metadata.block_table_tensor block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping slot_mapping = common_attn_metadata.slot_mapping
# Create KV cache # Create KV cache
kv_cache = torch.empty( kv_cache = torch.zeros(
2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device 2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device
) )
kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size) kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)
...@@ -205,6 +216,7 @@ def run_attention_backend( ...@@ -205,6 +216,7 @@ def run_attention_backend(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_type: AttentionType = AttentionType.DECODER,
sliding_window: int | None = None, sliding_window: int | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl.""" """Run attention computation using the specified backend's AttentionImpl."""
...@@ -272,6 +284,7 @@ def run_attention_backend( ...@@ -272,6 +284,7 @@ def run_attention_backend(
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
alibi_slopes=None, alibi_slopes=None,
sliding_window=sliding_window, sliding_window=sliding_window,
attn_type=attn_type,
kv_cache_dtype="auto", kv_cache_dtype="auto",
) )
...@@ -295,6 +308,7 @@ def _test_backend_correctness( ...@@ -295,6 +308,7 @@ def _test_backend_correctness(
backend_to_test: list[AttentionBackendEnum | str], backend_to_test: list[AttentionBackendEnum | str],
mask_mod, mask_mod,
*, *,
attn_type: AttentionType = AttentionType.DECODER,
block_size: int = 16, block_size: int = 16,
atol: float = 1e-2, atol: float = 1e-2,
rtol: float = 1e-2, rtol: float = 1e-2,
...@@ -320,7 +334,7 @@ def _test_backend_correctness( ...@@ -320,7 +334,7 @@ def _test_backend_correctness(
multiple GPUs. This tests that backends work correctly with different multiple GPUs. This tests that backends work correctly with different
head counts. head counts.
""" """
current_platform.seed_everything(42) set_random_seed(42)
hf_config_override = None hf_config_override = None
if tensor_parallel_size > 1: if tensor_parallel_size > 1:
...@@ -432,6 +446,9 @@ def _test_backend_correctness( ...@@ -432,6 +446,9 @@ def _test_backend_correctness(
common_attn_metadata = create_common_attn_metadata( common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device batch_spec, vllm_config.cache_config.block_size, device
) )
if attn_type == AttentionType.ENCODER_ONLY:
# For encoder-only, all tokens are prefill tokens
common_attn_metadata.causal = False
# 3. Simulate Paged KV Cache and a realistic slot_mapping # 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache = create_and_prepopulate_kv_cache( kv_cache = create_and_prepopulate_kv_cache(
...@@ -487,6 +504,7 @@ def _test_backend_correctness( ...@@ -487,6 +504,7 @@ def _test_backend_correctness(
value_vllm, value_vllm,
kv_cache_for_backend, kv_cache_for_backend,
sliding_window=sliding_window, sliding_window=sliding_window,
attn_type=attn_type,
) )
finally: finally:
if reset_kv_cache_layout: if reset_kv_cache_layout:
...@@ -537,7 +555,7 @@ def _test_backend_correctness( ...@@ -537,7 +555,7 @@ def _test_backend_correctness(
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_causal_backend_correctness( def test_causal_backend_correctness(
batch_spec_name: str, model: str, tensor_parallel_size: int default_vllm_config, batch_spec_name: str, model: str, tensor_parallel_size: int
): ):
"""Test backend's correctness with causal attention.""" """Test backend's correctness with causal attention."""
...@@ -557,9 +575,21 @@ def test_causal_backend_correctness( ...@@ -557,9 +575,21 @@ def test_causal_backend_correctness(
if is_torch_equal_or_newer("2.9.0.dev0") if is_torch_equal_or_newer("2.9.0.dev0")
else [] else []
) )
SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS if current_platform.is_rocm():
] SMALL_BLOCK_BACKENDS = [
x
for x in BACKENDS_TO_TEST
if (
x not in LARGE_BLOCK_BACKENDS
and x is not AttentionBackendEnum.FLASH_ATTN
)
]
else:
SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness( _test_backend_correctness(
batch_spec, batch_spec,
model, model,
...@@ -580,12 +610,20 @@ def test_causal_backend_correctness( ...@@ -580,12 +610,20 @@ def test_causal_backend_correctness(
) )
SLIDING_WINDOW_BACKENDS_TO_TEST = [ if current_platform.is_rocm():
AttentionBackendEnum.FLASH_ATTN, # FLASH_ATTN is not supported on ROCm
AttentionBackendEnum.FLEX_ATTENTION, SLIDING_WINDOW_BACKENDS_TO_TEST = [
AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLEX_ATTENTION,
"FLEX_ATTENTION_SLOW", AttentionBackendEnum.TRITON_ATTN,
] "FLEX_ATTENTION_SLOW",
]
else:
SLIDING_WINDOW_BACKENDS_TO_TEST = [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
AttentionBackendEnum.TRITON_ATTN,
"FLEX_ATTENTION_SLOW",
]
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -652,3 +690,45 @@ def test_sliding_window_backend_correctness( ...@@ -652,3 +690,45 @@ def test_sliding_window_backend_correctness(
block_size=128, block_size=128,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
) )
@pytest.mark.parametrize(
"batch_spec_name",
[
"small_encoder_prefill",
"medium_encoder_prefill",
],
)
@pytest.mark.parametrize("model", ["google/embeddinggemma-300m"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_sliding_window_encoder_backend_correctness(
batch_spec_name: str, model: str, tensor_parallel_size: int
):
"""Test backend's correctness with sliding window attention."""
def bidi_sliding_window_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
*,
context_len: int,
sliding_window: int,
):
return torch.abs(q_idx + context_len - kv_idx) < sliding_window
batch_spec = BATCH_SPECS[batch_spec_name]
model_config = ModelConfig(model=model, max_model_len=max(batch_spec.seq_lens))
sliding_window = model_config.get_sliding_window()
sliding_window_mask_mod_fn = partial(
bidi_sliding_window_mask_mod, sliding_window=sliding_window
)
_test_backend_correctness(
batch_spec,
model,
SLIDING_WINDOW_BACKENDS_TO_TEST,
sliding_window_mask_mod_fn,
attn_type=AttentionType.ENCODER_ONLY,
tensor_parallel_size=tensor_parallel_size,
)
...@@ -79,7 +79,12 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend ...@@ -79,7 +79,12 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
], ],
) )
def test_mamba_layers_get_attn_backend( def test_mamba_layers_get_attn_backend(
dist_init, layer_class, init_kwargs, expected_backend, expected_mamba_type default_vllm_config,
dist_init,
layer_class,
init_kwargs,
expected_backend,
expected_mamba_type,
): ):
"""Test that Mamba-like layers return the correct attention backend.""" """Test that Mamba-like layers return the correct attention backend."""
layer = layer_class(**init_kwargs) layer = layer_class(**init_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment