Commit 7e63ef82 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0' into v0.14.0-dev

parents 8cbcac5d b17039bc
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for AITER MLA FP8 support detection.
These tests verify that the _check_aiter_mla_fp8_support() function
correctly handles various error conditions without crashing.
"""
from unittest.mock import patch
import pytest
class TestAiterMlaFp8SupportCheck:
"""Test cases for _check_aiter_mla_fp8_support() function."""
def setup_method(self):
"""Reset the global cache before each test."""
import vllm._aiter_ops as aiter_ops
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_import_error_handling(self, mock_supported):
"""Test that ImportError is handled gracefully."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
# Should return False without raising
with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=ImportError("No module"),
):
result = _check_aiter_mla_fp8_support()
assert result is False
@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_module_not_found_error_handling(self, mock_supported):
"""Test that ModuleNotFoundError is handled gracefully."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=ModuleNotFoundError("Module not found"),
):
# Should return False without raising
assert _check_aiter_mla_fp8_support() is False
# Cache should be set to False
assert aiter_ops._AITER_MLA_SUPPORTS_FP8 is False
@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_attribute_error_handling(self, mock_supported):
"""Test that AttributeError is handled gracefully."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=AttributeError("No attribute"),
):
assert _check_aiter_mla_fp8_support() is False
assert aiter_ops._AITER_MLA_SUPPORTS_FP8 is False
@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_value_error_handling(self, mock_supported):
"""Test that ValueError is handled gracefully (no signature)."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=ValueError("No signature"),
):
assert _check_aiter_mla_fp8_support() is False
assert aiter_ops._AITER_MLA_SUPPORTS_FP8 is False
@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_type_error_handling(self, mock_supported):
"""Test that TypeError is handled gracefully (not callable)."""
import vllm._aiter_ops as aiter_ops
from vllm._aiter_ops import _check_aiter_mla_fp8_support
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
with patch(
"vllm._aiter_ops.inspect.signature",
side_effect=TypeError("Not a callable"),
):
assert _check_aiter_mla_fp8_support() is False
assert aiter_ops._AITER_MLA_SUPPORTS_FP8 is False
@patch("vllm._aiter_ops.is_aiter_found_and_supported", return_value=True)
def test_result_caching(self, mock_supported):
"""Test that the result is cached after first check."""
import vllm._aiter_ops as aiter_ops
# Set cache to True
aiter_ops._AITER_MLA_SUPPORTS_FP8 = True
from vllm._aiter_ops import _check_aiter_mla_fp8_support
# Should return cached value without re-checking
result = _check_aiter_mla_fp8_support()
assert result is True
if __name__ == "__main__":
pytest.main([__file__, "-v"])
......@@ -5,9 +5,6 @@
# The utility function cannot be placed in `vllm.utils`
# this needs to be a standalone script
import sys
from contextlib import nullcontext
from vllm_test_utils import BlameResult, blame
# List of modules that should not be imported too early.
# Lazy import `torch._inductor.async_compile` to avoid creating
......@@ -16,26 +13,10 @@ from vllm_test_utils import BlameResult, blame
# `cv2` can easily mess up the environment.
module_names = ["torch._inductor.async_compile", "cv2"]
# set all modules in `module_names` to be None.
# if we import any modules during `import vllm`, there would be a
# hard error and nice stacktrace on the first import.
for module_name in module_names:
sys.modules[module_name] = None # type: ignore[assignment]
def any_module_imported():
return any(module_name in sys.modules for module_name in module_names)
# In CI, we only check finally if the module is imported.
# If it is indeed imported, we can rerun the test with `use_blame=True`,
# which will trace every function call to find the first import location,
# and help find the root cause.
# We don't run it in CI by default because it is slow.
use_blame = False
context = blame(any_module_imported) if use_blame else nullcontext()
with context as result:
import vllm # noqa
if use_blame:
assert isinstance(result, BlameResult)
print(f"the first import location is:\n{result.trace_stack}")
assert not any_module_imported(), (
f"Some the modules in {module_names} are imported. To see the first"
f" import location, run the test with `use_blame=True`."
)
import vllm # noqa
......@@ -18,25 +18,37 @@ for i in {1..5}; do
echo "Checking metadata.json URL (attempt $i)..."
if curl --fail "$meta_json_url" > metadata.json; then
echo "INFO: metadata.json URL is valid."
# check whether it is valid json by python
# check whether it is valid json by python (printed to stdout)
if python3 -m json.tool metadata.json; then
echo "INFO: metadata.json is valid JSON. Proceeding with the test."
echo "INFO: metadata.json is valid JSON. Proceeding with the check."
# check whether there is an object in the json matching:
# "package_name": "vllm", and "platform_tag" matches the current architecture
# see `determine_wheel_url` in setup.py for more details
if python3 -c "import platform as p,json as j,sys as s; d = j.load(open('metadata.json')); \
s.exit(int(not any(o.get('package_name') == 'vllm' and p.machine() in o.get('platform_tag') \
for o in d)))" 2>/dev/null; then
echo "INFO: metadata.json contains a pre-compiled wheel for the current architecture."
break
else
echo "WARN: metadata.json does not have a pre-compiled wheel for the current architecture."
fi
else
echo "CRITICAL: metadata.json exists but is not valid JSON, please do report in #sig-ci channel!"
echo "INFO: metadata.json content:"
cat metadata.json
exit 1
fi
break
fi
# failure handling
# failure handling & retry logic
if [ $i -eq 5 ]; then
echo "ERROR: metadata.json URL is still not valid after 5 attempts."
echo "ERROR: Please check whether the precompiled wheel for commit $merge_base_commit exists."
echo "ERROR: metadata is still not available after 5 attempts."
echo "ERROR: Please check whether the precompiled wheel for commit $merge_base_commit is available."
echo " NOTE: If $merge_base_commit is a new commit on main, maybe try again after its release pipeline finishes."
echo " NOTE: If it fails, please report in #sig-ci channel."
exit 1
else
echo "WARNING: metadata.json URL is not valid. Retrying in 3 minutes..."
sleep 180
echo "WARNING: metadata is not available. Retrying after 5 minutes..."
sleep 300
fi
done
......
......@@ -4,6 +4,11 @@
set -e
set -x
if command -v rocminfo >/dev/null 2>&1; then
echo "Skipping test for ROCm platform"
exit 0
fi
cd /vllm-workspace/
rm -rf .venv
......@@ -36,7 +41,7 @@ if diff before.txt after.txt; then
echo "torch version not overridden."
else
echo "torch version overridden by nightly_torch_test.txt, \
if the dependency is not triggered by the pytroch nightly test,\
if the dependency is not triggered by the pytorch nightly test,\
please add the dependency to the list 'white_list' in tools/pre_commit/generate_nightly_torch_test.py"
exit 1
fi
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
)
from vllm.v1.attention.backends.registry import (
AttentionBackendEnum,
MambaAttentionBackendEnum,
register_backend,
)
class CustomAttentionImpl(AttentionImpl):
"""Mock custom attention implementation for testing."""
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, *args, **kwargs):
"""Mock forward pass."""
pass
class CustomAttentionBackend(AttentionBackend):
"""Mock custom attention backend for testing."""
@staticmethod
def get_name():
return "CUSTOM"
@staticmethod
def get_impl_cls():
return CustomAttentionImpl
@staticmethod
def get_builder_cls():
"""Mock builder class."""
return None
@staticmethod
def get_required_kv_cache_layout():
"""Mock KV cache layout."""
return None
class CustomMambaAttentionImpl(AttentionImpl):
"""Mock custom mamba attention implementation for testing."""
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, *args, **kwargs):
"""Mock forward pass."""
pass
class CustomMambaAttentionBackend(AttentionBackend):
"""Mock custom mamba attention backend for testing."""
@staticmethod
def get_name():
return "CUSTOM_MAMBA"
@staticmethod
def get_impl_cls():
return CustomMambaAttentionImpl
@staticmethod
def get_builder_cls():
"""Mock builder class."""
return None
@staticmethod
def get_required_kv_cache_layout():
"""Mock KV cache layout."""
return None
def test_custom_is_not_alias_of_any_backend():
# Get all members of AttentionBackendEnum
all_backends = list(AttentionBackendEnum)
# Find any aliases of CUSTOM
aliases = []
for backend in all_backends:
if backend.name != "CUSTOM" and backend is AttentionBackendEnum.CUSTOM:
aliases.append(backend.name)
# CUSTOM should not be an alias of any other backend
assert len(aliases) == 0, (
f"BUG! CUSTOM is an alias of: {', '.join(aliases)}!\n"
f"CUSTOM.value = {repr(AttentionBackendEnum.CUSTOM.value)}\n"
f"This happens when CUSTOM has the same value as another backend.\n"
f"When you register to CUSTOM, you're actually registering to {aliases[0]}!\n"
f"All backend values:\n"
+ "\n".join(f" {b.name}: {repr(b.value)}" for b in all_backends)
)
# Verify CUSTOM has its own unique identity
assert AttentionBackendEnum.CUSTOM.name == "CUSTOM", (
f"CUSTOM.name should be 'CUSTOM', but got '{AttentionBackendEnum.CUSTOM.name}'"
)
def test_register_custom_backend_with_class_path():
# Register with explicit class path
register_backend(
backend=AttentionBackendEnum.CUSTOM,
class_path="tests.test_attention_backend_registry.CustomAttentionBackend",
is_mamba=False,
)
# Check that CUSTOM backend is registered
assert AttentionBackendEnum.CUSTOM.is_overridden(), (
"CUSTOM should be overridden after registration"
)
# Get the registered class path
class_path = AttentionBackendEnum.CUSTOM.get_path()
assert class_path == "tests.test_attention_backend_registry.CustomAttentionBackend"
# Get the backend class
backend_cls = AttentionBackendEnum.CUSTOM.get_class()
assert backend_cls.get_name() == "CUSTOM"
assert backend_cls.get_impl_cls() == CustomAttentionImpl
def test_mamba_custom_is_not_alias_of_any_backend():
# Get all mamba backends
all_backends = list(MambaAttentionBackendEnum)
# Find any aliases of CUSTOM
aliases = []
for backend in all_backends:
if backend.name != "CUSTOM" and backend is MambaAttentionBackendEnum.CUSTOM:
aliases.append(backend.name)
# CUSTOM should not be an alias of any other backend
assert len(aliases) == 0, (
f"BUG! MambaAttentionBackendEnum.CUSTOM is an alias of: {', '.join(aliases)}!\n"
f"CUSTOM.value = {repr(MambaAttentionBackendEnum.CUSTOM.value)}\n"
f"All mamba backend values:\n"
+ "\n".join(f" {b.name}: {repr(b.value)}" for b in all_backends)
)
def test_register_custom_mamba_backend_with_class_path():
# Register with explicit class path
register_backend(
backend=MambaAttentionBackendEnum.CUSTOM,
class_path="tests.test_attention_backend_registry.CustomMambaAttentionBackend",
is_mamba=True,
)
# Check that the backend is registered
assert MambaAttentionBackendEnum.CUSTOM.is_overridden()
# Get the registered class path
class_path = MambaAttentionBackendEnum.CUSTOM.get_path()
assert (
class_path
== "tests.test_attention_backend_registry.CustomMambaAttentionBackend"
)
# Get the backend class
backend_cls = MambaAttentionBackendEnum.CUSTOM.get_class()
assert backend_cls.get_name() == "CUSTOM_MAMBA"
assert backend_cls.get_impl_cls() == CustomMambaAttentionImpl
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import os
from dataclasses import MISSING, Field, asdict, dataclass, field
......@@ -25,7 +26,6 @@ from vllm.config.vllm import (
OPTIMIZATION_LEVEL_TO_CONFIG,
OptimizationLevel,
)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
from utils import models_path_prefix
......@@ -162,8 +162,9 @@ def test_get_pooling_config():
model_config = ModelConfig(model_id)
assert model_config.pooler_config is not None
assert model_config.pooler_config.normalize
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
assert model_config.pooler_config.use_activation
assert model_config.pooler_config.seq_pooling_type == "MEAN"
assert model_config.pooler_config.tok_pooling_type == "ALL"
@pytest.mark.skipif(
......@@ -171,7 +172,7 @@ def test_get_pooling_config():
)
def test_get_pooling_config_from_args():
model_id = os.path.join(models_path_prefix, "sentence-transformers/all-MiniLM-L12-v2")
pooler_config = PoolerConfig(pooling_type="CLS", normalize=True)
pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
model_config = ModelConfig(model_id, pooler_config=pooler_config)
assert asdict(model_config.pooler_config) == asdict(pooler_config)
......@@ -182,14 +183,25 @@ def test_get_pooling_config_from_args():
[
("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM
("intfloat/e5-small", "CLS", "MEAN"), # BertModel
],
)
def test_default_seq_pooling_type(model_id, default_pooling_type, pooling_type):
model_config = ModelConfig(model_id)
assert model_config._model_info.default_seq_pooling_type == default_pooling_type
assert model_config.pooler_config.seq_pooling_type == pooling_type
@pytest.mark.parametrize(
("model_id", "default_pooling_type", "pooling_type"),
[
("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward
("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"), # step reward
],
)
def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
def test_default_tok_pooling_type(model_id, default_pooling_type, pooling_type):
model_config = ModelConfig(model_id)
assert model_config._model_info.default_pooling_type == default_pooling_type
assert model_config.pooler_config.pooling_type == pooling_type
assert model_config._model_info.default_tok_pooling_type == default_pooling_type
assert model_config.pooler_config.tok_pooling_type == pooling_type
@pytest.mark.parametrize(
......@@ -207,8 +219,8 @@ def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
)
def test_moe_model_detection(model_id, expected_is_moe_model):
model_config = ModelConfig(model_id)
# Just check that is_moe_model field exists and is a boolean
assert model_config.is_model_moe() == expected_is_moe_model
# Just check that is_moe field exists and is a boolean
assert model_config.is_moe == expected_is_moe_model
@pytest.mark.parametrize(
......@@ -226,7 +238,7 @@ def test_moe_model_detection(model_id, expected_is_moe_model):
def test_is_quantized(model_id, quantized):
model_config = ModelConfig(model_id)
# Just check that quantized field exists and is a boolean
assert model_config.is_quantized() == quantized
assert model_config.is_quantized == quantized
@pytest.mark.skipif(
......@@ -556,100 +568,100 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
"jason9693/Qwen2.5-1.5B-apeach",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
),
(
"Qwen/Qwen3-Embedding-0.6B",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
),
(
"Qwen/Qwen2.5-Math-PRM-7B",
"decoder",
False,
"Pooling models with step pooling does not support chunked prefill.",
"Pooling models with causal attn and LAST/STEP pooling do not support chunked prefill.", # noqa: E501
),
(
"internlm/internlm2-1_8b-reward",
"decoder",
True,
"Pooling models with causal attn and all pooling support chunked prefill.",
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
),
(
"BAAI/bge-base-en",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
(
"boltuix/NeuroBERT-NER",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
(
"papluca/xlm-roberta-base-language-detection",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
(
"intfloat/e5-small",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
# multimodal models
(
"openai/clip-vit-base-patch32",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
"Pooling models with causal attn and LAST/ALL pooling support chunked prefill.", # noqa: E501
),
(
"google/siglip-base-patch16-224",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
"Pooling models with bidirectional attn do not support chunked prefill.", # noqa: E501
),
# generate models
(
"Qwen/Qwen3-0.6B",
"decoder",
True,
"Generative models support chunked prefill.",
"Generative models support chunked prefill.", # noqa: E501
),
(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid",
True,
"Generative models support chunked prefill.",
"Generative models support chunked prefill.", # noqa: E501
),
(
"ibm-granite/granite-4.0-h-small",
"hybrid",
True,
"Generative models support chunked prefill.",
"Generative models support chunked prefill.", # noqa: E501
),
(
"state-spaces/mamba-130m-hf",
"attention_free",
True,
"Generative models support chunked prefill.",
"Generative models support chunked prefill.", # noqa: E501
),
# encoder_decoder models
(
"openai/whisper-small",
"encoder_decoder",
False,
"Encoder decoder models does not support chunked prefill.",
"Encoder decoder models do not support chunked prefill.", # noqa: E501
),
],
)
......@@ -675,100 +687,100 @@ def test_is_chunked_prefill_supported(
"jason9693/Qwen2.5-1.5B-apeach",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
),
(
"Qwen/Qwen3-Embedding-0.6B",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
),
(
"Qwen/Qwen2.5-Math-PRM-7B",
"decoder",
False,
"Pooling models with step pooling does not support prefix caching.",
"Pooling models with causal attn and LAST/STEP pooling do not support prefix caching.", # noqa: E501
),
(
"internlm/internlm2-1_8b-reward",
"decoder",
True,
"Pooling models with causal attn and all pooling support prefix caching.",
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
),
(
"BAAI/bge-base-en",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
(
"boltuix/NeuroBERT-NER",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
(
"papluca/xlm-roberta-base-language-detection",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
(
"intfloat/e5-small",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
# multimodal models
(
"openai/clip-vit-base-patch32",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
"Pooling models with causal attn and LAST/ALL pooling support prefix caching.", # noqa: E501
),
(
"google/siglip-base-patch16-224",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
"Pooling models with bidirectional attn do not support prefix caching.", # noqa: E501
),
# generate models
(
"Qwen/Qwen3-0.6B",
"decoder",
True,
"Generative models support prefix caching.",
"Generative models support prefix caching.", # noqa: E501
),
(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid",
False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
"Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
),
(
"ibm-granite/granite-4.0-h-small",
"hybrid",
False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
"Hybrid models do not support prefix caching since the feature is still experimental.", # noqa: E501
),
(
"state-spaces/mamba-130m-hf",
"attention_free",
False,
"Attention free models does not support prefix caching since the feature is still experimental.", # noqa: E501
"Attention free models do not support prefix caching since the feature is still experimental.", # noqa: E501
),
# encoder_decoder models
(
"openai/whisper-small",
"encoder_decoder",
False,
"Encoder decoder models does not support prefix caching.",
"Encoder decoder models do not support prefix caching.", # noqa: E501
),
],
)
......@@ -927,7 +939,7 @@ def test_vllm_config_callable_defaults():
model_config=quantized_model, optimization_level=OptimizationLevel.O2
)
enable_if_quantized = lambda cfg: (
cfg.model_config is not None and cfg.model_config.is_quantized()
cfg.model_config is not None and cfg.model_config.is_quantized
)
assert enable_if_quantized(config_quantized) is True
assert enable_if_quantized(config_no_model) is False
......@@ -938,7 +950,7 @@ def test_vllm_config_callable_defaults():
model_config=moe_model, optimization_level=OptimizationLevel.O2
)
enable_if_sequential = lambda cfg: (
cfg.model_config is not None and not cfg.model_config.is_model_moe()
cfg.model_config is not None and not cfg.model_config.is_moe
)
assert enable_if_sequential(config_moe) is False
assert enable_if_sequential(config_quantized) is True
......@@ -1052,3 +1064,46 @@ def test_scheduler_config_init():
with pytest.raises(AttributeError):
# InitVar does not become an attribute
print(SchedulerConfig.default_factory().max_model_len)
@pytest.mark.parametrize(
(
"model_id",
"data_parallel_size",
"external_lb",
"expected_needs_coordinator",
),
[
# Non-MoE model with DP=1 should not need coordinator
("facebook/opt-125m", 1, False, False),
# Non-MoE model with DP>1 internal LB should need coordinator
("facebook/opt-125m", 2, False, True),
# Non-MoE model with DP>1 external LB should not need coordinator
("facebook/opt-125m", 2, True, False),
# MoE model with DP=1 should not need coordinator
("mistralai/Mixtral-8x7B-Instruct-v0.1", 1, False, False),
# MoE model with DP>1 internal LB should need both coordinator
# and wave coordination
("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, False, True),
# MoE model with DP>1 external LB needs coordinator for wave coordination
# (wave coordination runs in coordinator process)
("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, True, True),
],
)
def test_needs_dp_coordination(
model_id,
data_parallel_size,
external_lb,
expected_needs_coordinator,
):
"""Test that DP coordinator and wave coordination are configured correctly."""
from vllm.config import ParallelConfig
model_config = ModelConfig(model_id)
parallel_config = ParallelConfig(
data_parallel_size=data_parallel_size,
data_parallel_external_lb=external_lb,
)
vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)
assert vllm_config.needs_dp_coordinator == expected_needs_coordinator
......@@ -18,7 +18,7 @@ EMBEDDING_MODELS = [
]
classify_parameters = ["use_activation"]
embed_parameters = ["dimensions", "normalize"]
embed_parameters = ["dimensions", "use_activation"]
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
......@@ -40,19 +40,19 @@ def test_task():
def test_embed():
task = "embed"
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(normalize=None)
pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=True)
pooling_params = PoolingParams(use_activation=True)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=False)
pooling_params = PoolingParams(use_activation=False)
pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = classify_parameters + step_pooling_parameters
for p in invalid_parameters:
for p in set(invalid_parameters) - set(embed_parameters):
with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config)
......@@ -86,7 +86,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
@pytest.mark.parametrize("task", ["score", "classify"])
def test_classify(task):
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config)
......@@ -98,7 +98,7 @@ def test_classify(task):
pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = embed_parameters + step_pooling_parameters
for p in invalid_parameters:
for p in set(invalid_parameters) - set(classify_parameters):
with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config)
......@@ -108,23 +108,23 @@ def test_classify(task):
def test_token_embed(pooling_type: str):
task = "token_embed"
model_config = MockModelConfig(
pooler_config=PoolerConfig(pooling_type=pooling_type)
pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
)
pooling_params = PoolingParams(normalize=None)
pooling_params = PoolingParams(use_activation=None)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=True)
pooling_params = PoolingParams(use_activation=True)
pooling_params.verify(task=task, model_config=model_config)
pooling_params = PoolingParams(normalize=False)
pooling_params = PoolingParams(use_activation=False)
pooling_params.verify(task=task, model_config=model_config)
invalid_parameters = classify_parameters
if pooling_type != "STEP":
invalid_parameters = classify_parameters + step_pooling_parameters
for p in invalid_parameters:
for p in set(invalid_parameters) - set(embed_parameters):
with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config)
......@@ -134,7 +134,7 @@ def test_token_embed(pooling_type: str):
def test_token_classify(pooling_type: str):
task = "token_classify"
model_config = MockModelConfig(
pooler_config=PoolerConfig(pooling_type=pooling_type)
pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
)
pooling_params = PoolingParams(use_activation=None)
......@@ -150,7 +150,7 @@ def test_token_classify(pooling_type: str):
if pooling_type != "STEP":
invalid_parameters = embed_parameters + step_pooling_parameters
for p in invalid_parameters:
for p in set(invalid_parameters) - set(classify_parameters):
with pytest.raises(ValueError):
pooling_params = PoolingParams(**{p: True})
pooling_params.verify(task=task, model_config=model_config)
......@@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s
# Test the select_experts method
topk_weights, topk_ids, _ = fused_moe.select_experts(
topk_weights, topk_ids = fused_moe.router.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
......
......@@ -10,6 +10,7 @@ from transformers import (
)
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
......@@ -37,6 +38,10 @@ def test_tokenizer_like_protocol():
assert isinstance(tokenizer, MistralTokenizer)
_assert_tokenizer_like(tokenizer)
tokenizer = get_tokenizer("xai-org/grok-2", tokenizer_mode="grok2")
assert isinstance(tokenizer, Grok2Tokenizer)
_assert_tokenizer_like(tokenizer)
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
def test_tokenizer_revision(tokenizer_name: str):
......
......@@ -40,7 +40,8 @@ TOKENIZERS = [
os.path.join(models_path_prefix, "EleutherAI/gpt-j-6b"),
os.path.join(models_path_prefix, "EleutherAI/pythia-70m"),
os.path.join(models_path_prefix, "bigscience/bloom-560m"),
os.path.join(models_path_prefix, "mosaicml/mpt-7b"),
# FIXME: mosaicml/mpt-7b has been deleted
# "mosaicml/mpt-7b",
os.path.join(models_path_prefix, "tiiuae/falcon-7b"),
os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct"),
os.path.join(models_path_prefix, "codellama/CodeLlama-7b-hf"),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.tool_parsers.functiongemma_tool_parser import FunctionGemmaToolParser
@pytest.fixture
def mock_tokenizer():
tokenizer = MagicMock()
tokenizer.encode.return_value = [1, 2, 3]
tokenizer.get_vocab.return_value = {}
return tokenizer
@pytest.fixture
def parser(mock_tokenizer):
return FunctionGemmaToolParser(mock_tokenizer)
@pytest.fixture
def mock_request():
request = MagicMock(spec=ChatCompletionRequest)
request.tools = []
request.tool_choice = "auto"
return request
class TestExtractToolCalls:
def test_no_tool_calls(self, parser, mock_request):
model_output = "Hello, how can I help you today?"
result = parser.extract_tool_calls(model_output, mock_request)
assert result.tools_called is False
assert result.tool_calls == []
assert result.content == model_output
def test_single_tool_call(self, parser, mock_request):
model_output = (
"<start_function_call>call:get_weather{location:<escape>London<escape>}"
"<end_function_call>"
)
result = parser.extract_tool_calls(model_output, mock_request)
assert result.tools_called is True
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "get_weather"
assert '"location": "London"' in result.tool_calls[0].function.arguments
def test_multiple_arguments(self, parser, mock_request):
model_output = (
"<start_function_call>call:get_weather{"
"location:<escape>San Francisco<escape>,"
"unit:<escape>celsius<escape>}"
"<end_function_call>"
)
result = parser.extract_tool_calls(model_output, mock_request)
assert result.tools_called is True
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "get_weather"
args = result.tool_calls[0].function.arguments
assert "San Francisco" in args
assert "celsius" in args
def test_text_before_tool_call(self, parser, mock_request):
model_output = (
"Let me check the weather for you. "
"<start_function_call>call:get_weather{location:<escape>Paris<escape>}"
"<end_function_call>"
)
result = parser.extract_tool_calls(model_output, mock_request)
assert result.tools_called is True
assert result.content == "Let me check the weather for you."
def test_multiple_tool_calls(self, parser, mock_request):
model_output = (
"<start_function_call>call:get_weather{location:<escape>London<escape>}"
"<end_function_call>"
"<start_function_call>call:get_time{timezone:<escape>UTC<escape>}"
"<end_function_call>"
)
result = parser.extract_tool_calls(model_output, mock_request)
assert result.tools_called is True
assert len(result.tool_calls) == 2
assert result.tool_calls[0].function.name == "get_weather"
assert result.tool_calls[1].function.name == "get_time"
class TestParseArguments:
def test_empty_arguments(self, parser):
result = parser._parse_arguments("")
assert result == {}
def test_single_string_argument(self, parser):
result = parser._parse_arguments("city:<escape>Tokyo<escape>")
assert result == {"city": "Tokyo"}
def test_multiple_arguments(self, parser):
args_str = "city:<escape>Tokyo<escape>,country:<escape>Japan<escape>"
result = parser._parse_arguments(args_str)
assert result == {"city": "Tokyo", "country": "Japan"}
def test_numeric_argument(self, parser):
result = parser._parse_arguments("count:<escape>42<escape>")
assert result == {"count": 42}
def test_boolean_argument(self, parser):
result = parser._parse_arguments("enabled:<escape>true<escape>")
assert result == {"enabled": True}
def test_argument_with_spaces(self, parser):
result = parser._parse_arguments("message:<escape>Hello World<escape>")
assert result == {"message": "Hello World"}
class TestAdjustRequest:
def test_skip_special_tokens_disabled(self, parser, mock_request):
mock_request.tools = [{"type": "function", "function": {"name": "test"}}]
mock_request.tool_choice = "auto"
mock_request.skip_special_tokens = True
result = parser.adjust_request(mock_request)
assert result.skip_special_tokens is False
def test_skip_special_tokens_when_tool_choice_none(self, parser, mock_request):
mock_request.tools = [{"type": "function", "function": {"name": "test"}}]
mock_request.tool_choice = "none"
mock_request.skip_special_tokens = True
result = parser.adjust_request(mock_request)
assert result.skip_special_tokens is True
class TestBufferDeltaText:
def test_regular_text_not_buffered(self, parser):
result = parser._buffer_delta_text("hello")
assert result == "hello"
assert parser.buffered_delta_text == ""
def test_complete_tag_flushed(self, parser):
parser.buffered_delta_text = "<start_function_"
result = parser._buffer_delta_text("call>")
assert "<start_function_call>" in result
if __name__ == "__main__":
pytest.main([__file__, "-v"])
......@@ -44,6 +44,33 @@ def assert_tool_calls(
)
def run_streaming_sequence(parser, deltas):
"""Helper to simulate a streaming sequence and return results."""
previous_text = ""
previous_token_ids: list[int] = []
results = []
for delta_text, delta_token_ids in deltas:
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + delta_token_ids
result = parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=None,
)
results.append(result)
previous_text = current_text
previous_token_ids = current_token_ids
return results
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
model_output = "This is a test"
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
......@@ -346,61 +373,32 @@ def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser):
tool_call_begin_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
# Simulate streaming sequence:
deltas = [
("I'll help you with that. ", [1, 2, 3]),
("<|tool_calls_section_begin|>", [section_begin_token_id]),
(" spurious text ", [4, 5]),
("<|tool_call_begin|>", [tool_call_begin_token_id]),
]
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
# Delta 1: "I'll help you with that. "
result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text="I'll help you with that. ",
delta_text="I'll help you with that. ",
previous_token_ids=[],
current_token_ids=[1, 2, 3], # Regular tokens
delta_token_ids=[1, 2, 3],
request=None,
)
assert result1 is not None
assert result1.content == "I'll help you with that. "
assert results[0] is not None
assert results[0].content == "I'll help you with that. "
# Delta 2: "<|tool_calls_section_begin|>"
prev_ids = [1, 2, 3]
curr_ids = prev_ids + [section_begin_token_id]
result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="I'll help you with that. ",
current_text="I'll help you with that. <|tool_calls_section_begin|>",
delta_text="<|tool_calls_section_begin|>",
previous_token_ids=prev_ids,
current_token_ids=curr_ids,
delta_token_ids=[section_begin_token_id],
request=None,
)
# Section marker should be stripped and suppressed
assert result2 is None or (result2.content is None or result2.content == "")
assert results[1] is None or (
results[1].content is None or results[1].content == ""
)
# Delta 3: " spurious text or tokens " (THE LEAK SCENARIO)
prev_ids = curr_ids
curr_ids = curr_ids + [4, 5]
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="I'll help you with that. <|tool_calls_section_begin|>",
current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
delta_text=" spurious text ",
previous_token_ids=prev_ids,
current_token_ids=curr_ids,
delta_token_ids=[4, 5],
request=None,
)
# CRITICAL: This text should be suppressed, NOT returned as reasoning_delta
assert result3 is None or (result3.content is None or result3.content == "")
assert results[2] is None or (
results[2].content is None or results[2].content == ""
)
# Delta 4: "<|tool_call_begin|>..."
prev_ids = curr_ids
curr_ids = curr_ids + [tool_call_begin_token_id]
_result4 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text <|tool_call_begin|>",
delta_text="<|tool_call_begin|>",
previous_token_ids=prev_ids,
current_token_ids=curr_ids,
delta_token_ids=[tool_call_begin_token_id],
request=None,
)
# Now we're in tool call mode, result depends on internal state
# The key is that the spurious text from Delta 3 was not leaked
......@@ -416,31 +414,15 @@ def test_split_markers_across_deltas(kimi_k2_tool_parser):
"<|tool_calls_section_begin|>"
)
# Delta 1: "...reasoning<|tool_calls_sec"
_result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="Some reasoning",
current_text="Some reasoning<|tool_calls_sec",
delta_text="<|tool_calls_sec",
previous_token_ids=[1, 2],
current_token_ids=[1, 2, 3], # Partial token
delta_token_ids=[3],
request=None,
)
# Partial token not recognized yet, might be buffered
# Should return as content or None (depends on implementation)
# Delta 1: partial token, Delta 2: complete marker
deltas = [
("<|tool_calls_sec", [3]),
("tion_begin|> ", [section_begin_token_id, 4]),
]
_results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
# Delta 2: "tion_begin|> " (completes the marker)
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="Some reasoning<|tool_calls_sec",
current_text="Some reasoning<|tool_calls_section_begin|> ",
delta_text="tion_begin|> ",
previous_token_ids=[1, 2, 3],
current_token_ids=[1, 2, section_begin_token_id, 4],
delta_token_ids=[section_begin_token_id, 4],
request=None,
)
# Now the complete marker should be detected via buffer
# The parser should enter tool section mode
assert kimi_k2_tool_parser.in_tool_section is True
......@@ -475,42 +457,17 @@ def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser):
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
# Enter tool section
_result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text="<|tool_calls_section_begin|>",
delta_text="<|tool_calls_section_begin|>",
previous_token_ids=[],
current_token_ids=[section_begin_id],
delta_token_ids=[section_begin_id],
request=None,
)
assert kimi_k2_tool_parser.in_tool_section is True
deltas = [
("<|tool_calls_section_begin|>", [section_begin_id]),
("<|tool_calls_section_end|>", [section_end_id]),
(" More reasoning", [10, 11]),
]
# Exit tool section
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="<|tool_calls_section_begin|>",
current_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
delta_text="<|tool_calls_section_end|>",
previous_token_ids=[section_begin_id],
current_token_ids=[section_begin_id, section_end_id],
delta_token_ids=[section_end_id],
request=None,
)
assert kimi_k2_tool_parser.in_tool_section is False
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
# Subsequent reasoning text should be returned normally
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
current_text="<|tool_calls_section_begin|><|tool_calls_section_end|> More reasoning",
delta_text=" More reasoning",
previous_token_ids=[section_begin_id, section_end_id],
current_token_ids=[section_begin_id, section_end_id, 10, 11],
delta_token_ids=[10, 11],
request=None,
)
assert result3 is not None
assert result3.content == " More reasoning"
assert kimi_k2_tool_parser.in_tool_section is False
assert results[2] is not None
assert results[2].content == " More reasoning"
def test_empty_tool_section(kimi_k2_tool_parser):
......@@ -819,106 +776,150 @@ def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser):
tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
# Simulate a streaming sequence for a SHORT tool call (all in one chunk):
# 1. Reasoning text
result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text="Let me help. ",
delta_text="Let me help. ",
previous_token_ids=[],
current_token_ids=[1, 2],
delta_token_ids=[1, 2],
request=None,
)
assert result1 is not None
assert result1.content == "Let me help. "
# 2. Section begin
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="Let me help. ",
current_text="Let me help. <|tool_calls_section_begin|>",
delta_text="<|tool_calls_section_begin|>",
previous_token_ids=[1, 2],
current_token_ids=[1, 2, section_begin_id],
delta_token_ids=[section_begin_id],
request=None,
)
assert kimi_k2_tool_parser.in_tool_section is True
# 3. Tool call begin + full content + tool_end + section_end ALL IN ONE CHUNK
# This is the critical scenario for short tool calls
combined = (
'<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} '
"<|tool_call_end|><|tool_calls_section_end|>"
)
# Build up the previous text gradually to simulate realistic streaming
prev_text = "Let me help. <|tool_calls_section_begin|>"
curr_text = prev_text + combined
deltas = [
("Let me help. ", [1, 2]),
("<|tool_calls_section_begin|>", [section_begin_id]),
(combined, [tool_begin_id, 10, 11, 12, tool_end_id, section_end_id]),
(" Done", [20]),
]
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text=prev_text,
current_text=curr_text,
delta_text=combined,
previous_token_ids=[1, 2, section_begin_id],
current_token_ids=[
1,
2,
section_begin_id,
tool_begin_id,
10,
11,
12,
tool_end_id,
section_end_id,
],
delta_token_ids=[tool_begin_id, 10, 11, 12, tool_end_id, section_end_id],
request=None,
)
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
# CRITICAL: Parser should have exited section AFTER processing tool
assert kimi_k2_tool_parser.in_tool_section is False
# Tool call should have been emitted (not dropped)
# The result might be the tool name or None depending on state, but
# importantly, it shouldn't be returning the literal tokens as content
if result3 is not None and result3.content is not None:
if results[2] is not None and results[2].content is not None:
# Verify no special tokens leaked into content
assert "<|tool_call_end|>" not in result3.content
assert "<|tool_calls_section_end|>" not in result3.content
assert "<|tool_call_end|>" not in results[2].content
assert "<|tool_calls_section_end|>" not in results[2].content
# 4. Verify subsequent content streams normally
result4 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text=curr_text,
current_text=curr_text + " Done",
delta_text=" Done",
previous_token_ids=[
1,
2,
section_begin_id,
tool_begin_id,
10,
11,
12,
tool_end_id,
section_end_id,
],
current_token_ids=[
1,
2,
section_begin_id,
tool_begin_id,
10,
11,
12,
tool_end_id,
section_end_id,
20,
],
delta_token_ids=[20],
request=None,
# Content after tool section should stream normally
assert results[3] is not None
assert results[3].content == " Done"
def test_streaming_tool_call_markers_not_leaked(kimi_k2_tool_parser):
"""
CRITICAL TEST: Verify that tool call markers (<|tool_call_begin|>,
<|tool_call_end|>, <|tool_call_argument_begin|>) are NOT leaked
into the content field during streaming.
This reproduces the AWS Bedrock bug where tool call markers appeared
in the 'text' field of responses.
"""
kimi_k2_tool_parser.reset_streaming_state()
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
# List of markers that should NEVER appear in content
forbidden_markers = [
"<|tool_call_begin|>",
"<|tool_call_end|>",
"<|tool_call_argument_begin|>",
"<|tool_calls_section_begin|>",
"<|tool_calls_section_end|>",
]
all_content = []
# Steps: reasoning, section begin, tool call, section end, more reasoning
tool_chunk = (
"<|tool_call_begin|> functions.get_weather:0 "
'<|tool_call_argument_begin|> {"city": "Tokyo"} <|tool_call_end|>'
)
deltas = [
("I'll check the weather. ", [1, 2, 3]),
("<|tool_calls_section_begin|>", [section_begin_id]),
(tool_chunk, [tool_begin_id, 10, 11, tool_end_id]),
("<|tool_calls_section_end|>", [section_end_id]),
(" Here's the result.", [20, 21]),
]
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
for res in results:
if res and res.content:
all_content.append(res.content)
# CRITICAL ASSERTIONS: No forbidden markers in any content
full_content = "".join(all_content)
for marker in forbidden_markers:
assert marker not in full_content, (
f"MARKER LEAK DETECTED: '{marker}' found in content. "
f"Full content: {repr(full_content)}"
)
# Content after tool section should stream normally
assert result4 is not None
assert result4.content == " Done"
# Also check that tool call content (function name, arguments) is not leaked
assert "get_weather" not in full_content, (
f"TOOL CALL CONTENT LEAKED: 'get_weather' found in content. "
f"Full content: {repr(full_content)}"
)
assert "Tokyo" not in full_content, (
f"TOOL CALL CONTENT LEAKED: 'Tokyo' found in content. "
f"Full content: {repr(full_content)}"
)
# Verify that legitimate content was preserved
assert "I'll check the weather." in full_content or len(all_content) > 0
def test_streaming_multiple_tool_calls_not_leaked(kimi_k2_tool_parser):
"""
Test that MULTIPLE tool calls in streaming mode do not leak into content.
This reproduces the AWS Bedrock scenario: "Compare weather in Tokyo and NYC".
"""
kimi_k2_tool_parser.reset_streaming_state()
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
all_content = []
tool1 = '<|tool_call_begin|> get_weather:0 <|tool_call_argument_begin|> {"city": "Tokyo"} <|tool_call_end|>'
tool2 = ' <|tool_call_begin|> get_weather:1 <|tool_call_argument_begin|> {"city": "New York"} <|tool_call_end|>'
deltas = [
("I'll compare the weather. ", [1, 2, 3]),
("<|tool_calls_section_begin|>", [section_begin_id]),
(tool1, [tool_begin_id, 10, tool_end_id]),
(tool2, [tool_begin_id, 20, tool_end_id]),
("<|tool_calls_section_end|>", [section_end_id]),
(" Here's the comparison.", [30]),
]
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
for res in results:
if res and res.content:
all_content.append(res.content)
# Assertions
full_content = "".join(all_content)
# Check no markers leaked
forbidden = ["<|tool_call", "<|tool_calls_section"]
for marker in forbidden:
assert marker not in full_content, (
f"MARKER LEAKED: {marker} in {repr(full_content)}"
)
# Check no tool call content leaked (both tools)
assert "get_weather" not in full_content, f"TOOL NAME LEAKED: {repr(full_content)}"
assert "Tokyo" not in full_content, f"TOOL ARG LEAKED (Tokyo): {repr(full_content)}"
assert "New York" not in full_content, (
f"TOOL ARG LEAKED (NYC): {repr(full_content)}"
)
# Legitimate content preserved
assert "compare" in full_content.lower() or len(all_content) > 0
......@@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer(
"single_tool_add",
"single_tool_weather",
"multiple_tool_calls",
"complex",
"wrong_json",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
......@@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer(
],
None,
),
(
# Complex
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
)[:-2],
)
)
],
"hi{hi",
),
(
# Wrong json
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
),
)
)
],
"hi{hi",
),
],
)
def test_extract_tool_calls(
......@@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming(
),
(
# Complex
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
......@@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming(
)
)
],
"",
"hi{hi",
),
],
)
......
......@@ -151,3 +151,45 @@ async def test_chat_completion_with_tools(
assert chunk.choices[0].finish_reason != "tool_calls"
assert len(chunks)
assert "".join(chunks) == output_text
# Regression test for https://github.com/vllm-project/vllm/issues/32006
# Engine crash when combining response_format: json_object with
# tool_choice: required
@pytest.mark.asyncio
@pytest.mark.timeout(120)
async def test_response_format_with_tool_choice_required(
client: openai.AsyncOpenAI, server_config: ServerConfig
):
"""
Test that combining response_format: json_object with tool_choice: required
doesn't crash the engine.
Before the fix, this would cause a validation error:
"You can only use one kind of structured outputs constraint but multiple
are specified" because both json_object and json (from tool schema) would
be set in StructuredOutputsParams.
"""
models = await client.models.list()
model_name: str = models.data[0].id
# This combination previously crashed the engine
chat_completion = await client.chat.completions.create(
messages=ensure_system_prompt(
[{"role": "user", "content": "What is the weather in Dallas, Texas?"}],
server_config,
),
temperature=0,
max_completion_tokens=150,
model=model_name,
tools=[WEATHER_TOOL],
tool_choice="required",
response_format={"type": "json_object"},
)
# The fix clears response_format when tool_choice forces tool calling,
# so the request should complete successfully with tool calls
choice = chat_completion.choices[0]
assert choice.finish_reason == "tool_calls"
assert choice.message.tool_calls is not None
assert len(choice.message.tool_calls) > 0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from vllm.tool_parsers.minimax_m2_tool_parser import (
MinimaxM2ToolParser,
)
pytestmark = pytest.mark.cpu_test
class FakeTokenizer:
"""Minimal fake tokenizer that exposes the attributes used by the
parser: a truthy model_tokenizer marker and a vocab mapping for the
special tokens.
"""
def __init__(self):
self.model_tokenizer = True
# The parser will look up start/end tokens by their literal strings
self.vocab = {
"<minimax:tool_call>": 1,
"</minimax:tool_call>": 2,
}
def get_vocab(self):
return self.vocab
@pytest.fixture
def minimax_m2_tool_parser():
return MinimaxM2ToolParser(FakeTokenizer())
def test_extract_tool_calls_streaming_incremental(minimax_m2_tool_parser):
parser = minimax_m2_tool_parser
parser._reset_streaming_state()
chunks = [
"<minimax:tool_call>",
'<invoke name="get_weather">',
'<parameter name="city">',
"Seattle</parameter>",
"</invoke></minimax:tool_call>",
]
previous = ""
for chunk in chunks:
current = previous + chunk
delta = chunk
parser.extract_tool_calls_streaming(
previous_text=previous,
current_text=current,
delta_text=delta,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
previous = current
assert len(parser.prev_tool_call_arr) == 1
entry = parser.prev_tool_call_arr[0]
assert entry["name"] == "get_weather"
args = entry["arguments"]
assert args["city"] == "Seattle"
def test_streaming_minimax_m2_multiple_invokes(minimax_m2_tool_parser):
parser = minimax_m2_tool_parser
parser._reset_streaming_state()
chunks = [
"<minimax:tool_call>",
'<invoke name="search_web">',
'<parameter name="query_tag">',
'["technology", "events"]</parameter>',
'<parameter name="query_list">',
'["OpenAI", "latest", "release"]</parameter>',
"</invoke>",
'<invoke name="search_web">',
'<parameter name="query_tag">',
'["technology", "events"]</parameter>',
'<parameter name="query_list">',
'["Gemini", "latest", "release"]</parameter>',
"</invoke>",
"</minimax:tool_call>",
]
previous = ""
for chunk in chunks:
current = previous + chunk
delta = chunk
parser.extract_tool_calls_streaming(
previous_text=previous,
current_text=current,
delta_text=delta,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
previous = current
assert len(parser.prev_tool_call_arr) == 2
for entry, expect_model in zip(parser.prev_tool_call_arr, ["OpenAI", "Gemini"]):
assert entry["name"] == "search_web"
args = json.dumps(entry["arguments"])
assert "technology" in args and "events" in args
assert expect_model in args
# check streamed_args_for_tool for serving_chat.py
for index in range(2):
expected_call = parser.prev_tool_call_arr[index].get("arguments", {})
expected_call = json.dumps(expected_call)
actual_call = parser.streamed_args_for_tool[index]
assert expected_call == actual_call
......@@ -311,6 +311,7 @@ def test_streaming_output_valid(output, empty_params, delta_len):
previous_text = current_text
assert len(messages) > 0
combined_messages = "["
for message in messages:
if message.tool_calls[0].function.name:
......@@ -328,3 +329,35 @@ def test_streaming_output_valid(output, empty_params, delta_len):
combined_messages += "}]"
assert json.loads(combined_messages) == output
assert json.dumps(json.loads(combined_messages)) == output_json
def test_streaming_output_valid_with_trailing_extra_data():
self = MagicMock()
output = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
output_json = json.dumps(output) + "\nDONE"
previous_text = ""
function_name_returned = False
messages = []
delta_len = 3
for i in range(0, len(output_json), delta_len):
delta_text = output_json[i : i + delta_len]
current_text = previous_text + delta_text
delta_message, function_name_returned = (
OpenAIServingChat.extract_tool_call_required_streaming(
self,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
)
)
if delta_message:
messages.append(delta_message)
previous_text = current_text
assert len(messages) > 0
......@@ -112,6 +112,7 @@ class RemoteOpenAIServer:
env.update(env_dict)
serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
print(f"Environment variables: {env}")
self.proc: subprocess.Popen = subprocess.Popen(
serve_cmd,
env=env,
......@@ -726,13 +727,34 @@ def init_test_distributed_environment(
distributed_init_port: str,
local_rank: int = -1,
) -> None:
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
# Note: This function is often called from Ray worker processes, so we
# can't rely on pytest fixtures to set the config. We check if the config
# is already set and only create a default one if needed.
from vllm.config import (
VllmConfig,
get_current_vllm_config_or_none,
set_current_vllm_config,
)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
if get_current_vllm_config_or_none() is not None:
# Config already set, use it directly
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
)
else:
# No config set, create a default one for the test
with set_current_vllm_config(VllmConfig()):
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
)
ensure_model_parallel_initialized(tp_size, pp_size)
......
......@@ -99,30 +99,18 @@ def _test_stream_thread(main_expected_stream: torch.cuda.Stream):
def test_current_stream_multithread():
from vllm.platforms import current_platform
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
if current_platform.is_rocm():
main_dedicated_stream = current_stream()
assert main_dedicated_stream.cuda_stream != 0, (
"ROCm should create a dedicated stream, not use default stream (0x0)"
)
main_stream_again = current_stream()
assert main_stream_again == main_dedicated_stream, (
"Multiple calls to current_stream should return the same dedicated stream"
)
main_dedicated_stream = current_stream()
_test_stream_thread(main_dedicated_stream)
else:
main_default_stream = torch.cuda.default_stream()
main_initial_stream = current_stream()
assert main_dedicated_stream.cuda_stream != 0, (
"ROCm/CUDA should create a dedicated stream, not use default stream (0x0)"
)
assert main_initial_stream == main_default_stream, (
"First call to current_stream should return default stream on CUDA"
)
main_stream_again = current_stream()
assert main_stream_again == main_dedicated_stream, (
"Multiple calls to current_stream should return the same dedicated stream"
)
_test_stream_thread(main_default_stream)
_test_stream_thread(main_dedicated_stream)
......@@ -15,13 +15,17 @@ from tests.v1.attention.utils import (
create_vllm_config,
try_get_attention_backend,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
is_torch_equal_or_newer,
set_random_seed,
)
from vllm.v1.attention.backend import AttentionType, CommonAttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
set_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec
......@@ -79,6 +83,13 @@ BATCH_SPECS = {
),
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
# encoder-only
"small_encoder_prefill": BatchSpec(
seq_lens=[32, 64, 128, 256], query_lens=[32, 64, 128, 256]
),
"medium_encoder_prefill": BatchSpec(
seq_lens=[256, 512, 1024, 2048], query_lens=[256, 512, 1024, 2048]
),
}
......@@ -114,17 +125,17 @@ def create_and_prepopulate_kv_cache(
Tuple of (kv_cache, updated_block_table)
"""
batch_size = len(k_contexts)
seq_lens = common_attn_metadata.seq_lens_cpu
seq_lens = common_attn_metadata.seq_lens.cpu()
query_lens = (
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
context_lens = common_attn_metadata.num_computed_tokens_cpu
context_lens = seq_lens - query_lens
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
# Create KV cache
kv_cache = torch.empty(
kv_cache = torch.zeros(
2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device
)
kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)
......@@ -205,6 +216,7 @@ def run_attention_backend(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_type: AttentionType = AttentionType.DECODER,
sliding_window: int | None = None,
) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl."""
......@@ -272,6 +284,7 @@ def run_attention_backend(
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=sliding_window,
attn_type=attn_type,
kv_cache_dtype="auto",
)
......@@ -295,6 +308,7 @@ def _test_backend_correctness(
backend_to_test: list[AttentionBackendEnum | str],
mask_mod,
*,
attn_type: AttentionType = AttentionType.DECODER,
block_size: int = 16,
atol: float = 1e-2,
rtol: float = 1e-2,
......@@ -320,7 +334,7 @@ def _test_backend_correctness(
multiple GPUs. This tests that backends work correctly with different
head counts.
"""
current_platform.seed_everything(42)
set_random_seed(42)
hf_config_override = None
if tensor_parallel_size > 1:
......@@ -432,6 +446,9 @@ def _test_backend_correctness(
common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device
)
if attn_type == AttentionType.ENCODER_ONLY:
# For encoder-only, all tokens are prefill tokens
common_attn_metadata.causal = False
# 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache = create_and_prepopulate_kv_cache(
......@@ -487,6 +504,7 @@ def _test_backend_correctness(
value_vllm,
kv_cache_for_backend,
sliding_window=sliding_window,
attn_type=attn_type,
)
finally:
if reset_kv_cache_layout:
......@@ -537,7 +555,7 @@ def _test_backend_correctness(
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_causal_backend_correctness(
batch_spec_name: str, model: str, tensor_parallel_size: int
default_vllm_config, batch_spec_name: str, model: str, tensor_parallel_size: int
):
"""Test backend's correctness with causal attention."""
......@@ -557,9 +575,21 @@ def test_causal_backend_correctness(
if is_torch_equal_or_newer("2.9.0.dev0")
else []
)
SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
if current_platform.is_rocm():
SMALL_BLOCK_BACKENDS = [
x
for x in BACKENDS_TO_TEST
if (
x not in LARGE_BLOCK_BACKENDS
and x is not AttentionBackendEnum.FLASH_ATTN
)
]
else:
SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness(
batch_spec,
model,
......@@ -580,12 +610,20 @@ def test_causal_backend_correctness(
)
SLIDING_WINDOW_BACKENDS_TO_TEST = [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
AttentionBackendEnum.TRITON_ATTN,
"FLEX_ATTENTION_SLOW",
]
if current_platform.is_rocm():
# FLASH_ATTN is not supported on ROCm
SLIDING_WINDOW_BACKENDS_TO_TEST = [
AttentionBackendEnum.FLEX_ATTENTION,
AttentionBackendEnum.TRITON_ATTN,
"FLEX_ATTENTION_SLOW",
]
else:
SLIDING_WINDOW_BACKENDS_TO_TEST = [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
AttentionBackendEnum.TRITON_ATTN,
"FLEX_ATTENTION_SLOW",
]
@pytest.mark.parametrize(
......@@ -652,3 +690,45 @@ def test_sliding_window_backend_correctness(
block_size=128,
tensor_parallel_size=tensor_parallel_size,
)
@pytest.mark.parametrize(
"batch_spec_name",
[
"small_encoder_prefill",
"medium_encoder_prefill",
],
)
@pytest.mark.parametrize("model", ["google/embeddinggemma-300m"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
def test_sliding_window_encoder_backend_correctness(
batch_spec_name: str, model: str, tensor_parallel_size: int
):
"""Test backend's correctness with sliding window attention."""
def bidi_sliding_window_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
*,
context_len: int,
sliding_window: int,
):
return torch.abs(q_idx + context_len - kv_idx) < sliding_window
batch_spec = BATCH_SPECS[batch_spec_name]
model_config = ModelConfig(model=model, max_model_len=max(batch_spec.seq_lens))
sliding_window = model_config.get_sliding_window()
sliding_window_mask_mod_fn = partial(
bidi_sliding_window_mask_mod, sliding_window=sliding_window
)
_test_backend_correctness(
batch_spec,
model,
SLIDING_WINDOW_BACKENDS_TO_TEST,
sliding_window_mask_mod_fn,
attn_type=AttentionType.ENCODER_ONLY,
tensor_parallel_size=tensor_parallel_size,
)
......@@ -79,7 +79,12 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
],
)
def test_mamba_layers_get_attn_backend(
dist_init, layer_class, init_kwargs, expected_backend, expected_mamba_type
default_vllm_config,
dist_init,
layer_class,
init_kwargs,
expected_backend,
expected_mamba_type,
):
"""Test that Mamba-like layers return the correct attention backend."""
layer = layer_class(**init_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment