Commit a810671a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
...@@ -75,7 +75,6 @@ def test_models( ...@@ -75,7 +75,6 @@ def test_models(
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("TOKENIZERS_PARALLELISM", "true") m.setenv("TOKENIZERS_PARALLELISM", "true")
m.setenv("VLLM_ATTENTION_BACKEND", backend)
MAX_MODEL_LEN = 1024 MAX_MODEL_LEN = 1024
NUM_LOG_PROBS = 8 NUM_LOG_PROBS = 8
...@@ -86,6 +85,7 @@ def test_models( ...@@ -86,6 +85,7 @@ def test_models(
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
kv_cache_dtype="auto", kv_cache_dtype="auto",
attention_config={"backend": backend},
) as vllm_model: ) as vllm_model:
baseline_outputs = vllm_model.generate_greedy_logprobs( baseline_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS example_prompts, max_tokens, NUM_LOG_PROBS
...@@ -97,6 +97,7 @@ def test_models( ...@@ -97,6 +97,7 @@ def test_models(
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
attention_config={"backend": backend},
) as vllm_model: ) as vllm_model:
test_outputs = vllm_model.generate_greedy_logprobs( test_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS example_prompts, max_tokens, NUM_LOG_PROBS
......
...@@ -107,11 +107,12 @@ def can_initialize( ...@@ -107,11 +107,12 @@ def can_initialize(
patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1), patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1),
monkeypatch.context() as m, monkeypatch.context() as m,
): ):
if model_arch == "GptOssForCausalLM": # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
# FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when # L4 supports FA3.
# L4 supports FA3. attention_config = (
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") {"backend": "TRITON_ATTN"} if model_arch == "GptOssForCausalLM" else None
)
if model_arch == "WhisperForConditionalGeneration": if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
...@@ -142,6 +143,7 @@ def can_initialize( ...@@ -142,6 +143,7 @@ def can_initialize(
else "vllm", else "vllm",
hf_overrides=hf_overrides_fn, hf_overrides=hf_overrides_fn,
max_num_seqs=model_info.max_num_seqs, max_num_seqs=model_info.max_num_seqs,
attention_config=attention_config,
) )
......
...@@ -5,9 +5,6 @@ ...@@ -5,9 +5,6 @@
# The utility function cannot be placed in `vllm.utils` # The utility function cannot be placed in `vllm.utils`
# this needs to be a standalone script # this needs to be a standalone script
import sys import sys
from contextlib import nullcontext
from vllm_test_utils import BlameResult, blame
# List of modules that should not be imported too early. # List of modules that should not be imported too early.
# Lazy import `torch._inductor.async_compile` to avoid creating # Lazy import `torch._inductor.async_compile` to avoid creating
...@@ -16,26 +13,10 @@ from vllm_test_utils import BlameResult, blame ...@@ -16,26 +13,10 @@ from vllm_test_utils import BlameResult, blame
# `cv2` can easily mess up the environment. # `cv2` can easily mess up the environment.
module_names = ["torch._inductor.async_compile", "cv2"] module_names = ["torch._inductor.async_compile", "cv2"]
# set all modules in `module_names` to be None.
# if we import any modules during `import vllm`, there would be a
# hard error and nice stacktrace on the first import.
for module_name in module_names:
sys.modules[module_name] = None # type: ignore[assignment]
def any_module_imported(): import vllm # noqa
return any(module_name in sys.modules for module_name in module_names)
# In CI, we only check finally if the module is imported.
# If it is indeed imported, we can rerun the test with `use_blame=True`,
# which will trace every function call to find the first import location,
# and help find the root cause.
# We don't run it in CI by default because it is slow.
use_blame = False
context = blame(any_module_imported) if use_blame else nullcontext()
with context as result:
import vllm # noqa
if use_blame:
assert isinstance(result, BlameResult)
print(f"the first import location is:\n{result.trace_stack}")
assert not any_module_imported(), (
f"Some the modules in {module_names} are imported. To see the first"
f" import location, run the test with `use_blame=True`."
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from vllm.tool_parsers.minimax_m2_tool_parser import (
MinimaxM2ToolParser,
)
pytestmark = pytest.mark.cpu_test
class FakeTokenizer:
"""Minimal fake tokenizer that exposes the attributes used by the
parser: a truthy model_tokenizer marker and a vocab mapping for the
special tokens.
"""
def __init__(self):
self.model_tokenizer = True
# The parser will look up start/end tokens by their literal strings
self.vocab = {
"<minimax:tool_call>": 1,
"</minimax:tool_call>": 2,
}
def get_vocab(self):
return self.vocab
@pytest.fixture
def minimax_m2_tool_parser():
return MinimaxM2ToolParser(FakeTokenizer())
def test_extract_tool_calls_streaming_incremental(minimax_m2_tool_parser):
parser = minimax_m2_tool_parser
parser._reset_streaming_state()
chunks = [
"<minimax:tool_call>",
'<invoke name="get_weather">',
'<parameter name="city">',
"Seattle</parameter>",
"</invoke></minimax:tool_call>",
]
previous = ""
for chunk in chunks:
current = previous + chunk
delta = chunk
parser.extract_tool_calls_streaming(
previous_text=previous,
current_text=current,
delta_text=delta,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
previous = current
assert len(parser.prev_tool_call_arr) == 1
entry = parser.prev_tool_call_arr[0]
assert entry["name"] == "get_weather"
args = entry["arguments"]
assert args["city"] == "Seattle"
def test_streaming_minimax_m2_multiple_invokes(minimax_m2_tool_parser):
parser = minimax_m2_tool_parser
parser._reset_streaming_state()
chunks = [
"<minimax:tool_call>",
'<invoke name="search_web">',
'<parameter name="query_tag">',
'["technology", "events"]</parameter>',
'<parameter name="query_list">',
'["OpenAI", "latest", "release"]</parameter>',
"</invoke>",
'<invoke name="search_web">',
'<parameter name="query_tag">',
'["technology", "events"]</parameter>',
'<parameter name="query_list">',
'["Gemini", "latest", "release"]</parameter>',
"</invoke>",
"</minimax:tool_call>",
]
previous = ""
for chunk in chunks:
current = previous + chunk
delta = chunk
parser.extract_tool_calls_streaming(
previous_text=previous,
current_text=current,
delta_text=delta,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
previous = current
assert len(parser.prev_tool_call_arr) == 2
for entry, expect_model in zip(parser.prev_tool_call_arr, ["OpenAI", "Gemini"]):
assert entry["name"] == "search_web"
args = json.dumps(entry["arguments"])
assert "technology" in args and "events" in args
assert expect_model in args
# check streamed_args_for_tool for serving_chat.py
for index in range(2):
expected_call = parser.prev_tool_call_arr[index].get("arguments", {})
expected_call = json.dumps(expected_call)
actual_call = parser.streamed_args_for_tool[index]
assert expected_call == actual_call
...@@ -323,6 +323,7 @@ def test_prefill_split_across_ubatches( ...@@ -323,6 +323,7 @@ def test_prefill_split_across_ubatches(
num_tokens, num_tokens,
batch_spec.batch_size, batch_spec.batch_size,
split_point=split_point, split_point=split_point,
num_ubatches=2,
) )
assert ubatch_slices is not None and len(ubatch_slices) == 2 assert ubatch_slices is not None and len(ubatch_slices) == 2
......
...@@ -172,7 +172,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): ...@@ -172,7 +172,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
) )
# Call the function # Call the function
result = make_local_attention_virtual_batches( result, _ = make_local_attention_virtual_batches(
attn_chunk_size, common_attn_metadata, block_size attn_chunk_size, common_attn_metadata, block_size
) )
......
...@@ -94,26 +94,20 @@ def mock_on_gfx9(): ...@@ -94,26 +94,20 @@ def mock_on_gfx9():
None, None,
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(), AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
), ),
# Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 # Test Case 9: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
(
{"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
None,
AttentionBackendEnum.ROCM_ATTN.get_path(),
),
# Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
( (
{"VLLM_ROCM_USE_AITER": "1"}, {"VLLM_ROCM_USE_AITER": "1"},
"TRITON_ATTN", "TRITON_ATTN",
AttentionBackendEnum.TRITON_ATTN.get_path(), AttentionBackendEnum.TRITON_ATTN.get_path(),
), ),
# Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0 # Test Case 10: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
# (explicitly disabled) # (explicitly disabled)
( (
{"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"}, {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
None, None,
AttentionBackendEnum.TRITON_ATTN.get_path(), AttentionBackendEnum.TRITON_ATTN.get_path(),
), ),
# Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN # Test Case 11: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
( (
{"VLLM_ROCM_USE_AITER": "1"}, {"VLLM_ROCM_USE_AITER": "1"},
"ROCM_ATTN", "ROCM_ATTN",
......
...@@ -249,8 +249,8 @@ def create_dummy_kv_cache( ...@@ -249,8 +249,8 @@ def create_dummy_kv_cache(
@dataclass @dataclass
class BackendConfig: class BackendConfig:
name: str name: str
env_vars: dict attention_config: dict
comp_config: dict # compilation config comp_config: dict
specific_gpu_arch: tuple | None = None specific_gpu_arch: tuple | None = None
...@@ -259,10 +259,10 @@ full_cg_backend_configs = { ...@@ -259,10 +259,10 @@ full_cg_backend_configs = {
# FA3 on Hopper # FA3 on Hopper
"FA3": BackendConfig( "FA3": BackendConfig(
name="FA3", name="FA3",
env_vars={ attention_config={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN", "backend": "FLASH_ATTN",
"VLLM_FLASH_ATTN_VERSION": "3", "flash_attn_version": 3,
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", "flash_attn_max_num_splits_for_cuda_graph": 16,
}, },
comp_config={ comp_config={
"cudagraph_mode": "FULL", "cudagraph_mode": "FULL",
...@@ -272,9 +272,7 @@ full_cg_backend_configs = { ...@@ -272,9 +272,7 @@ full_cg_backend_configs = {
# FlashMLA on Hopper # FlashMLA on Hopper
"FlashMLA": BackendConfig( "FlashMLA": BackendConfig(
name="FlashMLA", name="FlashMLA",
env_vars={ attention_config={"backend": "FLASHMLA"},
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
...@@ -283,9 +281,7 @@ full_cg_backend_configs = { ...@@ -283,9 +281,7 @@ full_cg_backend_configs = {
# Cutlass MLA on Blackwell # Cutlass MLA on Blackwell
"CutlassMLA": BackendConfig( "CutlassMLA": BackendConfig(
name="CutlassMLA", name="CutlassMLA",
env_vars={ attention_config={"backend": "CUTLASS_MLA"},
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
...@@ -294,9 +290,7 @@ full_cg_backend_configs = { ...@@ -294,9 +290,7 @@ full_cg_backend_configs = {
# FlashInfer MLA on Blackwell # FlashInfer MLA on Blackwell
"FlashInferMLA": BackendConfig( "FlashInferMLA": BackendConfig(
name="FlashInferMLA", name="FlashInferMLA",
env_vars={ attention_config={"backend": "FLASHINFER_MLA"},
"VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
...@@ -305,9 +299,9 @@ full_cg_backend_configs = { ...@@ -305,9 +299,9 @@ full_cg_backend_configs = {
# FlashAttention MLA on Hopper # FlashAttention MLA on Hopper
"FlashAttentionMLA": BackendConfig( "FlashAttentionMLA": BackendConfig(
name="FlashAttentionMLA", name="FlashAttentionMLA",
env_vars={ attention_config={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", "backend": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", "flash_attn_max_num_splits_for_cuda_graph": 16,
}, },
comp_config={ comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_mode": "FULL_DECODE_ONLY",
...@@ -317,10 +311,10 @@ full_cg_backend_configs = { ...@@ -317,10 +311,10 @@ full_cg_backend_configs = {
# FA2 # FA2
"FA2": BackendConfig( "FA2": BackendConfig(
name="FA2", name="FA2",
env_vars={ attention_config={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN", "backend": "FLASH_ATTN",
"VLLM_FLASH_ATTN_VERSION": "2", "flash_attn_version": 2,
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", "flash_attn_max_num_splits_for_cuda_graph": 16,
}, },
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
...@@ -329,7 +323,7 @@ full_cg_backend_configs = { ...@@ -329,7 +323,7 @@ full_cg_backend_configs = {
# Triton Attention # Triton Attention
"TritonAttn": BackendConfig( "TritonAttn": BackendConfig(
name="TritonAttn", name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"}, attention_config={"backend": "TRITON_ATTN"},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
...@@ -337,14 +331,17 @@ full_cg_backend_configs = { ...@@ -337,14 +331,17 @@ full_cg_backend_configs = {
# FlashInfer # FlashInfer
"FlashInfer": BackendConfig( "FlashInfer": BackendConfig(
name="FlashInfer", name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, attention_config={"backend": "FLASHINFER"},
comp_config={ comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
), ),
"RocmAttn": BackendConfig( "RocmAttn": BackendConfig(
name="RocmAttn", name="RocmAttn",
env_vars={"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"}, attention_config={
"backend": "ROCM_ATTN",
"use_prefill_decode_attention": True,
},
comp_config={ comp_config={
"cudagraph_mode": "FULL", "cudagraph_mode": "FULL",
}, },
......
...@@ -49,7 +49,10 @@ def _create_vllm_config( ...@@ -49,7 +49,10 @@ def _create_vllm_config(
mock_config.lora_config = None mock_config.lora_config = None
# Mimic the behavior of VllmConfig.__post_init__() # Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.mode == CompilationMode.VLLM_COMPILE: if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1() compilation_config.set_splitting_ops_for_v1(
all2all_backend=mock_config.parallel_config.all2all_backend,
data_parallel_size=mock_config.parallel_config.data_parallel_size,
)
# mimic VllmConfig.__post_init__ # mimic VllmConfig.__post_init__
if compilation_config.cudagraph_capture_sizes: if compilation_config.cudagraph_capture_sizes:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref import weakref
from contextlib import ExitStack from contextlib import ExitStack
...@@ -13,26 +11,6 @@ from vllm import LLM ...@@ -13,26 +11,6 @@ from vllm import LLM
from vllm.config import CompilationConfig, CompilationMode from vllm.config import CompilationConfig, CompilationMode
from vllm.platforms import current_platform from vllm.platforms import current_platform
@contextlib.contextmanager
def temporary_environ(env_vars):
"""
Temporarily set environment variables and restore them afterward.
We have to do this vs monkeypatch because monkeypatch doesn't work
with "module" scoped fixtures.
"""
original_env = {k: os.environ.get(k) for k in env_vars}
try:
os.environ.update(env_vars)
yield
finally:
for k, v in original_env.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v
# test attention backend and cudagraph_mode combo # test attention backend and cudagraph_mode combo
# (backend_name, cudagraph_mode, supported) # (backend_name, cudagraph_mode, supported)
if current_platform.is_rocm(): if current_platform.is_rocm():
...@@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte ...@@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
): ):
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
env_vars = backend_configs[backend_name].env_vars attention_config = backend_config.attention_config
with temporary_environ(env_vars), ExitStack() as stack: with ExitStack() as stack:
if not supported: if not supported:
stack.enter_context(pytest.raises(Exception)) stack.enter_context(pytest.raises(Exception))
...@@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte ...@@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
trust_remote_code=True, trust_remote_code=True,
gpu_memory_utilization=0.45, gpu_memory_utilization=0.45,
max_model_len=1024, max_model_len=1024,
attention_config=attention_config,
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode
), ),
...@@ -122,9 +101,10 @@ combo_cases_2 = [ ...@@ -122,9 +101,10 @@ combo_cases_2 = [
def test_cudagraph_compilation_combo( def test_cudagraph_compilation_combo(
backend_name, cudagraph_mode, compilation_mode, supported backend_name, cudagraph_mode, compilation_mode, supported
): ):
env_vars = backend_configs[backend_name].env_vars backend_config = backend_configs[backend_name]
attention_config = backend_config.attention_config
with temporary_environ(env_vars), ExitStack() as stack: with ExitStack() as stack:
if not supported: if not supported:
stack.enter_context(pytest.raises(Exception)) stack.enter_context(pytest.raises(Exception))
...@@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo( ...@@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo(
trust_remote_code=True, trust_remote_code=True,
gpu_memory_utilization=0.45, gpu_memory_utilization=0.45,
max_model_len=1024, max_model_len=1024,
attention_config=attention_config,
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=compilation_mode, cudagraph_mode=cudagraph_mode mode=compilation_mode, cudagraph_mode=cudagraph_mode
), ),
......
...@@ -28,7 +28,7 @@ IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90() ...@@ -28,7 +28,7 @@ IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
BACKENDS, BACKENDS,
) )
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
backend, monkeypatch: pytest.MonkeyPatch backend,
): ):
""" """
Ensures that the same request (the 'needle' prompt) yields identical output Ensures that the same request (the 'needle' prompt) yields identical output
...@@ -54,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -54,7 +54,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend) attention_config = {"backend": backend}
# Allow overrides from environment (useful for CI tuning) # Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism # "facebook/opt-125m" is too small, doesn't reliably test determinism
model = resolve_model_name(backend) model = resolve_model_name(backend)
...@@ -92,6 +92,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -92,6 +92,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
max_num_seqs=max_batch_size, max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util, gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len, max_model_len=max_model_len,
attention_config=attention_config,
) )
# Baseline generation for the needle prompt alone. # Baseline generation for the needle prompt alone.
...@@ -106,6 +107,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -106,6 +107,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
max_num_seqs=max_batch_size, max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util, gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len, max_model_len=max_model_len,
attention_config=attention_config,
) )
mismatches = 0 mismatches = 0
...@@ -163,10 +165,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -163,10 +165,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
BACKENDS, BACKENDS,
) )
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend, monkeypatch: pytest.MonkeyPatch backend,
): ):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = resolve_model_name(backend) model_name = resolve_model_name(backend)
...@@ -188,12 +188,12 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -188,12 +188,12 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm = LLM( llm = LLM(
model=model_name, model=model_name,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
# enable_prefix_caching=False,
max_num_seqs=32, max_num_seqs=32,
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", # not everything is supported dtype="bfloat16", # not everything is supported
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
) )
# Use more realistic prompts for better token generation # Use more realistic prompts for better token generation
...@@ -382,12 +382,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -382,12 +382,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
"backend", "backend",
BACKENDS, BACKENDS,
) )
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): def test_simple_generation(backend):
""" """
Simple test that runs the model with a basic prompt and prints the output. Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging. Useful for quick smoke testing and debugging.
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model = resolve_model_name(backend) model = resolve_model_name(backend)
llm = LLM( llm = LLM(
...@@ -399,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): ...@@ -399,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
dtype="bfloat16", dtype="bfloat16",
enable_prefix_caching=False, enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
) )
prompt = "the capital of france is" prompt = "the capital of france is"
...@@ -445,8 +445,6 @@ def test_logprobs_without_batch_invariance_should_fail( ...@@ -445,8 +445,6 @@ def test_logprobs_without_batch_invariance_should_fail(
The test will PASS if we detect differences (proving batch invariance matters). The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed). The test will FAIL if everything matches (suggesting batch invariance isn't needed).
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
# CRITICAL: Disable batch invariance for this test # CRITICAL: Disable batch invariance for this test
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0") monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False) monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
...@@ -466,6 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail( ...@@ -466,6 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
) )
# build ragged prompts to change shapes significantly across BS=1 vs BS=N # build ragged prompts to change shapes significantly across BS=1 vs BS=N
...@@ -650,7 +649,7 @@ def test_logprobs_without_batch_invariance_should_fail( ...@@ -650,7 +649,7 @@ def test_logprobs_without_batch_invariance_should_fail(
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
def test_decode_logprobs_match_prefill_logprobs( def test_decode_logprobs_match_prefill_logprobs(
backend, monkeypatch: pytest.MonkeyPatch backend,
): ):
""" """
Test that verifies decode logprobs match prefill logprobs. Test that verifies decode logprobs match prefill logprobs.
...@@ -665,8 +664,6 @@ def test_decode_logprobs_match_prefill_logprobs( ...@@ -665,8 +664,6 @@ def test_decode_logprobs_match_prefill_logprobs(
This ensures that the logprobs from decode are consistent with what This ensures that the logprobs from decode are consistent with what
we would get if we ran prefill on each prefix. we would get if we ran prefill on each prefix.
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
seed = int(os.getenv("VLLM_TEST_SEED", "12345")) seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed) random.seed(seed)
model_name = resolve_model_name(backend) model_name = resolve_model_name(backend)
...@@ -690,6 +687,7 @@ def test_decode_logprobs_match_prefill_logprobs( ...@@ -690,6 +687,7 @@ def test_decode_logprobs_match_prefill_logprobs(
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config={"backend": backend},
) )
# Use a few test prompts # Use a few test prompts
...@@ -921,6 +919,7 @@ def LLM_with_max_seqs( ...@@ -921,6 +919,7 @@ def LLM_with_max_seqs(
max_num_seqs: int, max_num_seqs: int,
gpu_memory_utilization: float, gpu_memory_utilization: float,
max_model_len: int, max_model_len: int,
attention_config: dict | None = None,
) -> LLM: ) -> LLM:
""" """
Helper to construct an LLM with a specific max_num_seqs (batch-size limit) Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
...@@ -935,6 +934,7 @@ def LLM_with_max_seqs( ...@@ -935,6 +934,7 @@ def LLM_with_max_seqs(
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False, enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
attention_config=attention_config,
# Enable for MOE models # Enable for MOE models
# enable_expert_parallel=True, # enable_expert_parallel=True,
) )
...@@ -136,11 +136,9 @@ def _compare_bs1_vs_bsn_single_process( ...@@ -136,11 +136,9 @@ def _compare_bs1_vs_bsn_single_process(
@skip_unsupported @skip_unsupported
@pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("backend", BACKENDS)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend: str, monkeypatch: pytest.MonkeyPatch backend: str,
) -> None: ) -> None:
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345"))) random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
# Override backend for this test (and the RemoteOpenAIServer child process).
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
model_name = resolve_model_name(backend) model_name = resolve_model_name(backend)
prompts_all = [_random_prompt(10, 50) for _ in range(32)] prompts_all = [_random_prompt(10, 50) for _ in range(32)]
...@@ -156,6 +154,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -156,6 +154,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
server_args: list[str] = [ server_args: list[str] = [
"--max-model-len=8192", "--max-model-len=8192",
"--max-num-seqs=32", "--max-num-seqs=32",
f"--attention-backend={backend}",
] ]
if tp_size: if tp_size:
server_args += ["-tp", tp_size] server_args += ["-tp", tp_size]
......
...@@ -142,16 +142,17 @@ def run_tests( ...@@ -142,16 +142,17 @@ def run_tests(
"""Test consistency of combos of async scheduling, preemption, """Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding.""" uni/multiproc executor with spec decoding."""
with monkeypatch.context() as m: # Determine attention config based on platform
# avoid precision errors if current_platform.is_rocm():
if current_platform.is_rocm(): if is_testing_with_spec_decoding:
if is_testing_with_spec_decoding: # Use TRITON_ATTN for spec decoding test for consistency
# Use TRITON_ATTN for spec decoding test for consistency attention_config = {"backend": "TRITON_ATTN"}
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
else:
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
else: else:
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") attention_config = {"backend": "ROCM_ATTN"}
else:
attention_config = {"backend": "FLEX_ATTENTION"}
with monkeypatch.context() as m:
# lock matmul precision to full FP32 (IEEE) # lock matmul precision to full FP32 (IEEE)
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee") m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
# m.setenv("VLLM_BATCH_INVARIANT", "1") # m.setenv("VLLM_BATCH_INVARIANT", "1")
...@@ -174,6 +175,7 @@ def run_tests( ...@@ -174,6 +175,7 @@ def run_tests(
spec_config, spec_config,
test_prefill_chunking=test_prefill_chunking, test_prefill_chunking=test_prefill_chunking,
is_testing_with_spec_decoding=is_testing_with_spec_decoding, is_testing_with_spec_decoding=is_testing_with_spec_decoding,
attention_config=attention_config,
) )
outputs.append(test_results) outputs.append(test_results)
...@@ -262,6 +264,7 @@ def run_test( ...@@ -262,6 +264,7 @@ def run_test(
spec_config: dict[str, Any] | None, spec_config: dict[str, Any] | None,
test_prefill_chunking: bool, test_prefill_chunking: bool,
is_testing_with_spec_decoding: bool = False, is_testing_with_spec_decoding: bool = False,
attention_config: dict[str, Any] | None = None,
): ):
spec_decoding = spec_config is not None spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = ( cache_arg: dict[str, Any] = (
...@@ -281,14 +284,6 @@ def run_test( ...@@ -281,14 +284,6 @@ def run_test(
print(f"---- TESTING {test_str}: {test_config}") print(f"---- TESTING {test_str}: {test_config}")
print("-" * 80) print("-" * 80)
# On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
# spec decoding test (TRITON_ATTN) for better precision.
# On others: always use float32.
if current_platform.is_rocm() and not is_testing_with_spec_decoding:
dtype = "float16"
else:
dtype = "float32"
with VllmRunner( with VllmRunner(
model, model,
max_model_len=512, max_model_len=512,
...@@ -298,9 +293,10 @@ def run_test( ...@@ -298,9 +293,10 @@ def run_test(
# enforce_eager=True, # enforce_eager=True,
async_scheduling=async_scheduling, async_scheduling=async_scheduling,
distributed_executor_backend=executor, distributed_executor_backend=executor,
dtype=dtype, dtype="float32",
speculative_config=spec_config, speculative_config=spec_config,
disable_log_stats=False, disable_log_stats=False,
attention_config=attention_config,
**cache_arg, **cache_arg,
) as vllm_model: ) as vllm_model:
results = [] results = []
......
...@@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test ...@@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
@create_new_process_for_each_test() @create_new_process_for_each_test()
@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"]) @pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_cascade_attention(example_system_message, monkeypatch, attn_backend): def test_cascade_attention(example_system_message, attn_backend):
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:" prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
if attn_backend == "FLASHINFER": if attn_backend == "FLASHINFER":
...@@ -19,19 +19,18 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend): ...@@ -19,19 +19,18 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
"needs investigation. See issue #25679." "needs investigation. See issue #25679."
) )
with monkeypatch.context() as m: llm = LLM(
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) model="Qwen/Qwen2-1.5B-Instruct", attention_config={"backend": attn_backend}
)
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct") sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
# No cascade attention.
# No cascade attention. single_prompt = [example_system_message + prompt]
single_prompt = [example_system_message + prompt] responses = llm.generate(single_prompt, sampling_params)
responses = llm.generate(single_prompt, sampling_params) ref_output = responses[0].outputs[0].text
ref_output = responses[0].outputs[0].text
# (Probably) Use cascade attention.
# (Probably) Use cascade attention. prompts = [example_system_message + prompt] * 64
prompts = [example_system_message + prompt] * 64 responses = llm.generate(prompts, sampling_params)
responses = llm.generate(prompts, sampling_params) for response in responses:
for response in responses: assert response.outputs[0].text == ref_output
assert response.outputs[0].text == ref_output
...@@ -438,25 +438,26 @@ def test_eagle_correctness( ...@@ -438,25 +438,26 @@ def test_eagle_correctness(
should be the same when using eagle speculative decoding. should be the same when using eagle speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size) model_setup: (method, model_name, eagle_model_name, tp_size)
""" """
with monkeypatch.context() as m: # Determine attention config
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN": # Scout requires default backend selection because vision encoder has
# Scout requires default backend selection # head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
# because vision encoder has head_dim 88 being incompatible # to Flex Attn
# with FLASH_ATTN and needs to fall back to Flex Attn if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
if current_platform.is_rocm():
# pass if not ROCm # TODO: Enable Flex Attn for spec_decode on ROCm
if current_platform.is_rocm(): pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
# TODO: Enable Flex Attn for spec_decode on ROCm attention_config = None # Let it fall back to default
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently") else:
else: attention_config = {"backend": attn_backend}
m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): with monkeypatch.context() as m:
pytest.skip( m.setenv("VLLM_MLA_DISABLE", "1")
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
if "deepseek" in model_setup[1].lower(): if "deepseek" in model_setup[1].lower():
...@@ -471,7 +472,10 @@ def test_eagle_correctness( ...@@ -471,7 +472,10 @@ def test_eagle_correctness(
max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
ref_llm = LLM( ref_llm = LLM(
model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size model=model_name,
max_model_len=max_model_len,
tensor_parallel_size=tp_size,
attention_config=attention_config,
) )
ref_outputs = ref_llm.chat(test_prompts, sampling_config) ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm del ref_llm
...@@ -492,6 +496,7 @@ def test_eagle_correctness( ...@@ -492,6 +496,7 @@ def test_eagle_correctness(
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill, enable_chunked_prefill=enable_chunked_prefill,
model_impl=model_impl, model_impl=model_impl,
attention_config=attention_config,
) )
spec_outputs = spec_llm.chat(test_prompts, sampling_config) spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0 matches = 0
......
...@@ -11,6 +11,13 @@ from vllm import SamplingParams ...@@ -11,6 +11,13 @@ from vllm import SamplingParams
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ErrorResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -484,6 +491,60 @@ async def test_dp_rank_argument(): ...@@ -484,6 +491,60 @@ async def test_dp_rank_argument():
pass pass
@pytest.mark.asyncio(scope="module")
async def test_header_dp_rank_argument():
with ExitStack() as after:
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
MODEL_NAME = "test-model"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
# Create models first
models = OpenAIServingModels(
engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
)
# Create serving chat instance
serving_chat = OpenAIServingChat(
engine_client=engine,
models=models,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
request_logger=None,
)
# Create a chat completion request
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": TEXT_PROMPT}],
max_tokens=100,
temperature=1.0,
seed=33,
)
# Test 1: Valid DP rank (0)
mock_raw_request = MagicMock()
mock_raw_request.headers = {"X-data-parallel-rank": "0"}
mock_raw_request.state = MagicMock()
# Should succeed with valid rank
response = await serving_chat.create_chat_completion(req, mock_raw_request)
assert isinstance(response, ChatCompletionResponse), (
"Expected a ChatCompletionResponse for valid DP rank"
)
# Test 2: Out-of-range DP rank (1)
mock_raw_request.headers = {"X-data-parallel-rank": "1"}
# should return ErrorResponse for out-of-range rank
response2 = await serving_chat.create_chat_completion(req, mock_raw_request)
assert isinstance(response2, ErrorResponse), (
"Expected an ErrorResponse for out-of-range DP rank"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_health(): async def test_check_health():
"""Test that check_health returns normally for healthy engine """Test that check_health returns normally for healthy engine
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch.cuda
from vllm import LLM, SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
def test_preprocess_error_handling(monkeypatch: pytest.MonkeyPatch):
"""Test that preprocessing errors are handled gracefully."""
assert not torch.cuda.is_initialized(), (
"fork needs to be used for the engine "
"core process and this isn't possible if cuda is already initialized"
)
# Store original method to call for non-failing requests
original_preprocess = EngineCore.preprocess_add_request
# Monkeypatch to make preprocess_add_request raise an exception
# only for requests with "FAIL" in the first token
def conditional_failing_preprocess(self, request: EngineCoreRequest):
# Fail if the first token id is 333
if request.prompt_token_ids and request.prompt_token_ids[0] == 333:
raise ValueError("Simulated preprocessing error!")
return original_preprocess(self, request)
monkeypatch.setattr(
EngineCore, "preprocess_add_request", conditional_failing_preprocess
)
llm = LLM(model=MODEL_NAME)
# Create a failing request by crafting a request with an invalid token
# We need to use a direct approach since LLM.generate tokenizes for us
from vllm.inputs import TokensPrompt
# This should raise an exception due to the preprocessing failure
# Special token id to trigger the failure
failing_prompt = TokensPrompt(prompt_token_ids=[333])
outputs = llm.generate(failing_prompt, SamplingParams(max_tokens=10)) # type: ignore
assert len(outputs) == 1
assert len(outputs[0].outputs[0].token_ids) == 0
assert outputs[0].finished
assert outputs[0].outputs[0].finish_reason == "error"
# Verify the engine is still functional with a normal request
outputs = llm.generate("Hello, my name is", SamplingParams(max_tokens=10))
assert len(outputs) == 1
assert len(outputs[0].outputs[0].token_ids) > 0
assert outputs[0].outputs[0].finish_reason in ("stop", "length")
...@@ -3,21 +3,29 @@ set -xe ...@@ -3,21 +3,29 @@ set -xe
# Parse command line arguments # Parse command line arguments
KV_BUFFER_DEVICE="cuda" # Default to cuda KV_BUFFER_DEVICE="cuda" # Default to cuda
ATTENTION_BACKEND="" # Default to empty (use vllm default)
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
case $1 in case $1 in
--kv_buffer_device) --kv_buffer_device)
KV_BUFFER_DEVICE="$2" KV_BUFFER_DEVICE="$2"
shift 2 shift 2
;; ;;
--attention-backend)
ATTENTION_BACKEND="$2"
shift 2
;;
*) *)
echo "Unknown option $1" echo "Unknown option $1"
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]" echo "Usage: $0 [--kv_buffer_device <cuda|cpu>] [--attention-backend <backend>]"
exit 1 exit 1
;; ;;
esac esac
done done
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE" echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
if [[ -n "$ATTENTION_BACKEND" ]]; then
echo "Using attention backend: $ATTENTION_BACKEND"
fi
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
...@@ -148,6 +156,11 @@ run_tests_for_model() { ...@@ -148,6 +156,11 @@ run_tests_for_model() {
--tensor-parallel-size $PREFILLER_TP_SIZE \ --tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'" --kv-transfer-config '$KV_CONFIG'"
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
if [ -n "$model_args" ]; then if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args" FULL_CMD="$BASE_CMD $model_args"
else else
...@@ -188,7 +201,12 @@ run_tests_for_model() { ...@@ -188,7 +201,12 @@ run_tests_for_model() {
--block-size ${DECODE_BLOCK_SIZE} \ --block-size ${DECODE_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'" --kv-transfer-config '$KV_CONFIG'"
# Add attention backend config if specified
if [[ -n "$ATTENTION_BACKEND" ]]; then
BASE_CMD="${BASE_CMD} --attention-backend=$ATTENTION_BACKEND"
fi
# DP-EP attention mode # DP-EP attention mode
if [[ -z "$DP_EP" ]]; then if [[ -z "$DP_EP" ]]; then
BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE" BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"
......
...@@ -8,21 +8,24 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh" ...@@ -8,21 +8,24 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
configs=( configs=(
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2" "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2" "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
) )
run_tests() { run_tests() {
local label=$1 local label=$1
local extra_env=$2 local extra_args=$2
echo "=== Running tests (${label}) ===" echo "=== Running tests (${label}) ==="
for cfg in "${configs[@]}"; do for cfg in "${configs[@]}"; do
echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}" echo "-> Running with ${cfg} ${extra_args:+and ${extra_args}}"
# Use 'env' to safely set variables without eval # Use 'env' to safely set variables without eval
if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then if ! env ${cfg} bash "${SCRIPT}" ${extra_args}; then
echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}" echo "❌ Test failed for config: ${cfg} ${extra_args:+(${extra_args})}"
exit 1 exit 1
fi fi
done done
...@@ -34,8 +37,8 @@ run_tests "default backend" "" ...@@ -34,8 +37,8 @@ run_tests "default backend" ""
# Check if FLASHINFER is set (non-empty) # Check if FLASHINFER is set (non-empty)
if [[ -n "${FLASHINFER:-}" ]]; then if [[ -n "${FLASHINFER:-}" ]]; then
echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER" echo "FLASHINFER is set, rerunning with --attention-backend FLASHINFER"
run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER" run_tests "FLASHINFER backend" "--attention-backend FLASHINFER"
else else
echo "FLASHINFER not set, skipping FLASHINFER runs." echo "FLASHINFER not set, skipping FLASHINFER runs."
fi fi
...@@ -391,6 +391,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -391,6 +391,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
def _nixl_handshake( def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
...@@ -407,22 +409,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -407,22 +409,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
assert expected_engine_id == self.REMOTE_ENGINE_ID assert expected_engine_id == self.REMOTE_ENGINE_ID
remote_agent_name = self.add_remote_agent( # Adjust remote block length metadata to satisfy heterogeneous TP
NixlAgentMetadata( # invariants enforced during handshake validation.
engine_id=self.REMOTE_ENGINE_ID, remote_block_lens = list(self.block_len_per_layer)
agent_metadata=FakeNixlWrapper.AGENT_METADATA, tp_ratio = self.kv_topo.tp_ratio(remote_tp_size)
kv_caches_base_addr=[0], if remote_tp_size > self.world_size:
device_id=0, # P TP > D TP case, block_len of remote is smaller
num_blocks=1, remote_block_lens = [
block_lens=self.block_len_per_layer, block_len // (-tp_ratio) for block_len in remote_block_lens
# `self.kv_cache_layout` is only forced to HND when vllm engine ]
# is started. We mock HND here. elif remote_tp_size < self.world_size:
kv_cache_layout="HND", remote_block_lens = [
block_size=self.block_size, block_len * tp_ratio for block_len in remote_block_lens
), ]
remote_tp_size=remote_tp_size,
) # When remote tp_size > local tp_size, handshake with multiple
return {0: remote_agent_name} # remote ranks.
num_hanshakes = 1 if tp_ratio > 0 else -tp_ratio
remote_agents: dict[int, str] = {}
for remote_tp_rank in range(num_hanshakes):
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=remote_tp_rank,
num_blocks=1,
block_lens=remote_block_lens,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
)
remote_agents[remote_tp_rank] = remote_agent_name
return remote_agents
class TestNixlHandshake: class TestNixlHandshake:
...@@ -453,7 +476,13 @@ class TestNixlHandshake: ...@@ -453,7 +476,13 @@ class TestNixlHandshake:
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config, connector.engine_id, hand_shake_latency=0
) )
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper) assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3) worker = connector.connector_worker
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
# simulate handshake
worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
}
worker.kv_cache_layout = "HND"
num_xfers = 4 num_xfers = 4
while True: while True:
# For the same request_id, initiate multiple xfers across different # For the same request_id, initiate multiple xfers across different
...@@ -567,6 +596,171 @@ class TestNixlHandshake: ...@@ -567,6 +596,171 @@ class TestNixlHandshake:
return return
raise TimeoutError("Took too long to complete async handshake.") raise TimeoutError("Took too long to complete async handshake.")
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
self, local_tp_size: int, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations.
"""
vllm_config = create_vllm_config()
local_tp_size = 1
vllm_config.parallel_config.tensor_parallel_size = local_tp_size
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
worker = connector.connector_worker
# Minimal local registration params used by add_remote_agent
worker.slot_size_per_layer = [4096]
worker.block_len_per_layer = [4096 * worker.block_size]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]
def check_handshake(remote_tp_size: int):
tp_ratio = remote_tp_size // local_tp_size
assert set(remote_agents.keys()) == set(range(tp_ratio))
remote_engine_id = worker.REMOTE_ENGINE_ID
assert worker._tp_size[remote_engine_id] == remote_tp_size
assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
# ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks
assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
assert remote_engine_id in worker.dst_xfer_side_handles
assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set(
range(tp_ratio)
)
remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=2,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(2)
# NOTE flexiblity: a second remote with higher number of ranks is
# discovered. This is not a scenario we actively support right now, but
# the connector allows it.
worker.REMOTE_ENGINE_ID = "remote_engine_2"
remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=6,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(6)
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
self, local_tp_size: int, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations for an MLA model.
"""
vllm_config = create_vllm_config()
d_tp_size = 1
p_tp_size = 2
# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p0.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p0.engine_id, hand_shake_latency=0
)
conn_p1.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p1.engine_id, hand_shake_latency=0
)
# Force P world size to 2 for both workers and emulate distinct tp_ranks.
# Also enable MLA path so that expected_finished_count is updated.
for rank, worker in enumerate(
(conn_p0.connector_worker, conn_p1.connector_worker)
):
worker.world_size = p_tp_size
worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size}
worker.tp_rank = rank
worker.use_mla = True
req_id = "req-ep-dp2-p0"
now = time.perf_counter()
# Register a request on P that is waiting for consumers to read
# (both workers track it).
conn_p0.connector_worker._reqs_to_send[req_id] = now + 10.0
conn_p0.connector_worker._reqs_to_process.add(req_id)
conn_p1.connector_worker._reqs_to_send[req_id] = now + 10.0
conn_p1.connector_worker._reqs_to_process.add(req_id)
# Simulate a read notification coming from D with (tp=1, dp=2).
notif = f"{req_id}:{d_tp_size}".encode()
# D0-0->P0 notif
conn_p0.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
"agent": [notif]
} # type: ignore[method-assign]
conn_p1.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
"agent": [notif]
} # type: ignore[method-assign]
# Trigger notification processing via get_finished().
done_sending0, _ = conn_p0.get_finished(finished_req_ids=set())
done_sending1, _ = conn_p1.get_finished(finished_req_ids=set())
assert req_id in done_sending0 and req_id in done_sending1
# E2E aggregation: ensure the aggregated output marks the request
# as finished using the connector's expected_finished_count.
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
aggregator = KVOutputAggregator.from_connector(conn_p0, world_size=2)
out0 = ModelRunnerOutput(
req_ids=[req_id],
req_id_to_index={req_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=done_sending0,
finished_recving=None,
),
)
out1 = ModelRunnerOutput(
req_ids=[req_id],
req_id_to_index={req_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=done_sending1,
finished_recving=None,
),
)
aggregated = aggregator.aggregate([out0, out1], output_rank=0)
assert aggregated.kv_connector_output is not None
assert aggregated.kv_connector_output.finished_sending == {req_id}
# Producers cleaned up state for the finished request.
assert req_id not in conn_p0.connector_worker._reqs_to_send
assert req_id not in conn_p0.connector_worker._reqs_to_process
assert req_id not in conn_p1.connector_worker._reqs_to_send
assert req_id not in conn_p1.connector_worker._reqs_to_process
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
...@@ -585,6 +779,9 @@ class TestNixlHandshake: ...@@ -585,6 +779,9 @@ class TestNixlHandshake:
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id vllm_config, connector.engine_id
) )
# Register (mocked) local xfer handler
# worker = connector.connector_worker
# worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
total_reqs = 5 total_reqs = 5
for i in range(total_reqs): for i in range(total_reqs):
...@@ -672,7 +869,6 @@ class TestNixlHandshake: ...@@ -672,7 +869,6 @@ class TestNixlHandshake:
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
# mismatched layout is expected to fail # mismatched layout is expected to fail
worker.add_remote_agent(meta, remote_tp_size=2) worker.add_remote_agent(meta, remote_tp_size=2)
with pytest.raises(AssertionError):
worker.add_remote_agent(meta, remote_tp_size=1) worker.add_remote_agent(meta, remote_tp_size=1)
@patch( @patch(
...@@ -1132,7 +1328,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int): ...@@ -1132,7 +1328,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
"TRITON_ATTN", "TRITON_ATTN",
], ],
) )
def test_register_kv_caches(dist_init, attn_backend, monkeypatch): def test_register_kv_caches(dist_init, attn_backend):
""" """
Test that register_kv_caches() properly calls nixl_wrapper methods with Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data. correct data.
...@@ -1144,9 +1340,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): ...@@ -1144,9 +1340,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
block layout info block layout info
""" """
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) vllm_config = create_vllm_config(attention_backend=attn_backend)
vllm_config = create_vllm_config()
# Import the appropriate backend based on the parameter # Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN": if attn_backend == "FLASH_ATTN":
...@@ -1359,8 +1553,11 @@ def test_shutdown_cleans_up_resources(dist_init): ...@@ -1359,8 +1553,11 @@ def test_shutdown_cleans_up_resources(dist_init):
patch.object(nixl_wrapper, "deregister_memory") as mock_dereg, patch.object(nixl_wrapper, "deregister_memory") as mock_dereg,
): ):
worker._recving_transfers = {"req1": [123]} worker._recving_transfers = {"req1": [123]}
worker.src_xfer_side_handle = 456 # Mock register_kv_cache which registers local handle
worker.dst_xfer_side_handles = {"engine1": 789} worker.src_xfer_handles_by_block_size = {worker.block_size: 455}
# P TP = 2 * D TP case, we should register 2 local handles
worker.src_xfer_handles_by_tp_ratio = {-2: [456, 457]}
worker.dst_xfer_side_handles = {"engine1": {0: 789}}
worker._remote_agents = {"engine1": {0: "agent1"}} worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"] worker._registered_descs = ["desc1", "desc2"]
...@@ -1381,8 +1578,10 @@ def test_shutdown_cleans_up_resources(dist_init): ...@@ -1381,8 +1578,10 @@ def test_shutdown_cleans_up_resources(dist_init):
mock_listener.join.assert_called_once() mock_listener.join.assert_called_once()
mock_rel_xfer.assert_called_once_with(123) mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2 assert mock_rel_dlist.call_count == 4
mock_rel_dlist.assert_any_call(456) # src handle mock_rel_dlist.assert_any_call(455) # src handle (whole region)
mock_rel_dlist.assert_any_call(456) # src handle (1st chunk)
mock_rel_dlist.assert_any_call(457) # src handle (2nd chunk)
mock_rel_dlist.assert_any_call(789) # dst handle mock_rel_dlist.assert_any_call(789) # dst handle
mock_rem_agent.assert_called_once_with("agent1") mock_rem_agent.assert_called_once_with("agent1")
assert mock_dereg.call_count == 2 assert mock_dereg.call_count == 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment