Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
......@@ -11,12 +11,15 @@ from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
reason="Requires CUDA and >= Hopper (SM90)",
not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
# Supports testing on Ampere and Ada Lovelace devices.
# Note: For devices with SM < 90, batch invariance does not support CUDA Graphs.
reason="Requires CUDA and >= Ampere (SM80)",
)
BACKENDS: list[str] = [
"FLASH_ATTN",
"TRITON_MLA",
]
if has_flashinfer():
......@@ -96,3 +99,7 @@ def _extract_step_logprobs(request_output):
return t, inner.token_ids
return None, None
def is_device_capability_below_90() -> bool:
    """Return True when the current platform does not report device capability 90."""
    meets_capability_90 = current_platform.has_device_capability(90)
    return not meets_capability_90
......@@ -9,10 +9,22 @@ correctly with the DeepSeek-V2-Lite model using GSM8K evaluation.
"""
import pytest
import torch
from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
from tests.utils import RemoteOpenAIServer
# Blackwell-class GPUs (e.g. B200) are identified by a compute-capability
# major version of 10 or higher on device 0.
IS_BLACKWELL = False
try:
    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability(0)
        IS_BLACKWELL = cap[0] >= 10
except Exception:
    # Detection failed: stay conservative and do not xfail by default.
    IS_BLACKWELL = False
MODEL_NAME = "deepseek-ai/DeepSeek-V2-Lite-Chat"
DP_SIZE = 2
......@@ -33,6 +45,13 @@ DEEPEP_BACKENDS = [
@pytest.mark.parametrize("all2all_backend", DEEPEP_BACKENDS)
@pytest.mark.xfail(
IS_BLACKWELL,
reason=(
"Temporary: DBO accuracy unstable on Blackwell "
"(doesn't meet expectation of MIN_ACCURACY = 0.62)"
),
)
def test_dbo_dp_ep_gsm8k(all2all_backend: str, num_gpus_available):
"""
Test DBO with DP+EP using GSM8K evaluation.
......
......@@ -124,6 +124,8 @@ def run_tests(
with monkeypatch.context() as m:
# avoid precision errors
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
# lock matmul precision to full FP32
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
# m.setenv("VLLM_BATCH_INVARIANT", "1")
outputs: list[tuple[str, list, list]] = []
for n, (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that verifies no implicit GPU-CPU synchronization occurs during
speculative decoding generation under expected conditions.
"""
import multiprocessing
import sys
import traceback
import pytest
import torch
@pytest.fixture
def sync_tracker():
    """
    Count lazy initialisations of ``CommonAttentionMetadata.seq_lens_cpu``.

    While the fixture is active the property is swapped for a wrapper that,
    whenever ``_seq_lens_cpu`` is still ``None`` on access, bumps a shared
    counter and writes a stack trace to stderr before delegating to the
    original getter. The yielded helper exposes the counter (``count``) and
    an ``assert_no_sync`` check. On teardown the original property is put
    back and ``torch._dynamo.reset()`` is called.
    """
    from vllm.v1.attention.backends.utils import CommonAttentionMetadata

    # Counter shared across processes (inherited by fork).
    lazy_init_total = multiprocessing.Value("i", 0)

    # Hold on to the property object (to restore it later) and to its getter
    # (to delegate to) before the class attribute is overwritten.
    pristine_property = CommonAttentionMetadata.seq_lens_cpu
    pristine_getter = pristine_property.fget

    def tracking_seq_lens_cpu(self):
        if self._seq_lens_cpu is None:
            with lazy_init_total.get_lock():
                lazy_init_total.value += 1
                count = lazy_init_total.value

            # Report right away so the trace shows up in subprocess output.
            err = sys.stderr
            print(f"\n{'=' * 60}", file=err)
            print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=err)
            print(f"{'=' * 60}", file=err)
            traceback.print_stack(file=err)
            print(f"{'=' * 60}\n", file=err)
            err.flush()

        return pristine_getter(self)

    # Install the instrumented property.
    CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)

    class SyncTracker:
        """Handle given to the test for inspecting the shared counter."""

        @property
        def count(self) -> int:
            return lazy_init_total.value

        def assert_no_sync(self, msg: str = ""):
            count = lazy_init_total.value
            assert count == 0, (
                f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
                f"{count} times. See stack traces above. {msg}"
            )

    yield SyncTracker()

    # Teardown: undo the patch and reset dynamo state.
    CommonAttentionMetadata.seq_lens_cpu = pristine_property
    torch._dynamo.reset()
# Test configurations: (model, spec_model, method, num_spec_tokens).
# Each pytest.param supplies exactly these four values, matching the
# parametrize argument string used by test_no_sync_with_spec_decode, plus a
# human-readable test id.
SPEC_DECODE_CONFIGS = [
    pytest.param(
        "meta-llama/Llama-3.2-1B-Instruct",
        "nm-testing/Llama3_2_1B_speculator.eagle3",
        "eagle3",
        2,
        id="eagle3-llama",
    ),
    pytest.param(
        "eagle618/deepseek-v3-random",
        "eagle618/eagle-deepseek-v3-random",
        "eagle",
        2,
        id="eagle-mla-deepseek",
    ),
]
@pytest.mark.parametrize(
    "model,spec_model,method,num_spec_tokens",
    SPEC_DECODE_CONFIGS,
)
def test_no_sync_with_spec_decode(
    sync_tracker,
    model: str,
    spec_model: str,
    method: str,
    num_spec_tokens: int,
):
    """
    Run a short speculative-decoding generation and check that the
    ``sync_tracker`` fixture recorded no ``seq_lens_cpu`` lazy init.
    """
    # vLLM is imported only here, i.e. AFTER the sync_tracker fixture has
    # applied its patch.
    from vllm import LLM, SamplingParams
    from vllm.distributed import cleanup_dist_env_and_memory

    spec_settings = {
        "method": method,
        "num_speculative_tokens": num_spec_tokens,
        "model": spec_model,
    }
    engine = LLM(
        model=model,
        max_model_len=256,
        speculative_config=spec_settings,
        enforce_eager=True,
        async_scheduling=True,
    )

    prompts = ["Hello, my name is"]
    greedy_params = SamplingParams(temperature=0, max_tokens=10)
    results = engine.generate(prompts, greedy_params)

    # One prompt in, one non-empty completion out.
    assert len(results) == 1
    first_completion = results[0].outputs[0]
    assert len(first_completion.text) > 0

    # Release the engine and its resources before checking the counter.
    del engine
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    sync_tracker.assert_no_sync()
......@@ -7,6 +7,7 @@ import pytest
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode
from vllm.platforms import current_platform
from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
......@@ -43,15 +44,26 @@ def test_prompts():
return prompts
@fork_new_process_for_each_test
use_fork_for_test = (
fork_new_process_for_each_test if not current_platform.is_rocm() else lambda x: x
)
@use_fork_for_test
@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_kv_sharing_fast_prefill(
monkeypatch: pytest.MonkeyPatch,
kv_sharing_fast_prefill: bool,
enforce_eager: bool,
test_prompts: list[str],
):
if not enforce_eager and current_platform.is_rocm():
# Relevant context: https://github.com/vllm-project/vllm/pull/29244
pytest.skip(
"ROCm: torch.compile produces incorrect output for gemma-3n's GELU "
"with tanh approximation. Use enforce_eager=True instead."
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
compilation_config = CompilationConfig(
# This allows vLLM compilation backend to handle allocating and
......@@ -65,6 +77,10 @@ def test_kv_sharing_fast_prefill(
with monkeypatch.context() as m:
# Make scheduling deterministic for reproducibility
if current_platform.is_rocm():
# Use spawn to prevent cuda re-initialization error
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
else:
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
prompts, answer, indices = prep_prompts(batch_size)
......
......@@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance(
# Expect the acceptance rate to improve.
assert first_accept_rate < last_accept_rate
# Heuristic: expect at least 82.5% acceptance rate at the end.
assert last_accept_rate > 0.825
# Heuristic: expect at least 80.0% acceptance rate at the end.
assert last_accept_rate > 0.80
del spec_llm
torch.cuda.empty_cache()
......@@ -402,7 +402,11 @@ def test_eagle_correctness(
# Scout requires default backend selection
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
pass
# pass if not ROCm
if current_platform.is_rocm():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
else:
m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
......@@ -413,9 +417,9 @@ def test_eagle_correctness(
"multi-token eagle spec decode on current platform"
)
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
if "deepseek" in model_setup[1].lower():
pytest.skip("FLASH_ATTN for deepseek not supported on ROCm platform")
pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
else:
m.setenv("VLLM_ROCM_USE_AITER", "1")
......
......@@ -148,7 +148,7 @@ run_epd_1e_1pd() {
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_connector": "ECExampleConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
......@@ -167,7 +167,7 @@ run_epd_1e_1pd() {
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_connector": "ECExampleConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
......@@ -348,7 +348,7 @@ run_epd_1e_1p_1d() {
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_connector": "ECExampleConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
......@@ -369,7 +369,7 @@ run_epd_1e_1p_1d() {
--max-num-seqs 128 \
--allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
--ec-transfer-config '{
"ec_connector": "ECSharedStorageConnector",
"ec_connector": "ECExampleConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for ECSharedStorageConnector.
Unit tests for ECExampleConnector.
"""
import os
......@@ -13,9 +13,9 @@ import torch
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorRole
from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import (
ECSharedStorageConnector,
ECSharedStorageConnectorMetadata,
from vllm.distributed.ec_transfer.ec_connector.example_connector import (
ECExampleConnector,
ECExampleConnectorMetadata,
MMMeta,
)
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
......@@ -81,12 +81,12 @@ def mock_request_with_3_mm():
# ------------------ Unit Tests ------------------ #
class TestECSharedStorageConnectorBasics:
class TestECExampleConnectorBasics:
"""Test basic EC connector functionality."""
def test_initialization_producer(self, mock_vllm_config_producer, temp_storage):
"""Test connector initializes correctly as producer."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
......@@ -98,7 +98,7 @@ class TestECSharedStorageConnectorBasics:
def test_initialization_consumer(self, mock_vllm_config_consumer, temp_storage):
"""Test connector initializes correctly as consumer."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
......@@ -109,11 +109,11 @@ class TestECSharedStorageConnectorBasics:
def test_role_assignment(self, mock_vllm_config_producer):
"""Test role is correctly assigned."""
scheduler_connector = ECSharedStorageConnector(
scheduler_connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
worker_connector = ECSharedStorageConnector(
worker_connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
......@@ -133,7 +133,7 @@ class TestCacheExistence:
):
"""Test has_caches returns True when all 3 caches exist."""
# Test for producer first
producer = ECSharedStorageConnector(
producer = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
......@@ -154,7 +154,7 @@ class TestCacheExistence:
assert all(producer_result), f"Expected all True, got {producer_result}"
# Also test consumer can check if cache exists
consumer = ECSharedStorageConnector(
consumer = ECExampleConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.SCHEDULER,
)
......@@ -170,7 +170,7 @@ class TestCacheExistence:
self, mock_vllm_config_producer, mock_request_with_3_mm
):
"""Test has_caches returns False when no caches exist."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
......@@ -186,7 +186,7 @@ class TestCacheExistence:
self, mock_vllm_config_producer, mock_request_with_3_mm
):
"""Test has_caches with some caches existing (1 of 3)."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
......@@ -213,7 +213,7 @@ class TestStateManagement:
self, mock_vllm_config_producer, mock_request_with_3_mm
):
"""Test state update after allocation for 3 MM items."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
......@@ -238,7 +238,7 @@ class TestStateManagement:
self, mock_vllm_config_producer, mock_request_with_3_mm
):
"""Test metadata building for 3 MM items."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
......@@ -252,7 +252,7 @@ class TestStateManagement:
metadata = connector.build_connector_meta(scheduler_output)
# Assert
assert isinstance(metadata, ECSharedStorageConnectorMetadata)
assert isinstance(metadata, ECExampleConnectorMetadata)
assert len(metadata.mm_datas) == 3
assert metadata.mm_datas[0].mm_hash == "img_hash_1"
assert metadata.mm_datas[0].num_token == 100
......@@ -266,7 +266,7 @@ class TestStateManagement:
def test_build_connector_meta_empty(self, mock_vllm_config_producer):
"""Test metadata building with empty state."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
......@@ -274,14 +274,14 @@ class TestStateManagement:
scheduler_output = Mock(spec=SchedulerOutput)
metadata = connector.build_connector_meta(scheduler_output)
assert isinstance(metadata, ECSharedStorageConnectorMetadata)
assert isinstance(metadata, ECExampleConnectorMetadata)
assert len(metadata.mm_datas) == 0
def test_state_cleared_after_metadata_build(
self, mock_vllm_config_producer, mock_request_with_3_mm
):
"""Test that state is properly cleared after building metadata."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
......@@ -310,7 +310,7 @@ class TestCacheSaving:
self, mock_vllm_config_producer, mock_request_with_3_mm, temp_storage
):
"""Test cache saving as producer for 3 different MM items."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
......@@ -336,7 +336,7 @@ class TestCacheSaving:
def test_save_caches_consumer_skips(self, mock_vllm_config_consumer):
"""Test cache saving is skipped for consumer."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
......@@ -366,7 +366,7 @@ class TestCacheLoading:
):
"""Test consumer loads 3 caches from storage."""
# First, create producer to save caches
producer = ECSharedStorageConnector(
producer = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
......@@ -379,13 +379,13 @@ class TestCacheLoading:
producer.save_caches(saved_caches, mm_hash)
# Now consumer loads
consumer = ECSharedStorageConnector(
consumer = ECExampleConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
# Setup metadata for all 3
metadata = ECSharedStorageConnectorMetadata()
metadata = ECExampleConnectorMetadata()
for mm_hash in mm_hashes:
metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100))
consumer.bind_connector_metadata(metadata)
......@@ -410,7 +410,7 @@ class TestCacheLoading:
):
"""Test cache loading skips already cached items."""
# Setup: producer saves cache
producer = ECSharedStorageConnector(
producer = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
......@@ -420,12 +420,12 @@ class TestCacheLoading:
producer.save_caches({mm_hash: saved_cache}, mm_hash)
# Consumer setup
consumer = ECSharedStorageConnector(
consumer = ECExampleConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
metadata = ECSharedStorageConnectorMetadata()
metadata = ECExampleConnectorMetadata()
metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100))
consumer.bind_connector_metadata(metadata)
......@@ -444,13 +444,13 @@ class TestCacheLoading:
def test_start_load_caches_empty_metadata(self, mock_vllm_config_consumer):
"""Test loading with empty metadata does nothing."""
consumer = ECSharedStorageConnector(
consumer = ECExampleConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
# Setup empty metadata
metadata = ECSharedStorageConnectorMetadata()
metadata = ECExampleConnectorMetadata()
consumer.bind_connector_metadata(metadata)
# Load (should not raise)
......@@ -466,7 +466,7 @@ class TestFilenameGeneration:
def test_generate_foldername(self, mock_vllm_config_producer, temp_storage):
"""Test folder name generation."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
......@@ -479,7 +479,7 @@ class TestFilenameGeneration:
def test_generate_filename(self, mock_vllm_config_producer, temp_storage):
"""Test filename generation."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
......@@ -493,7 +493,7 @@ class TestFilenameGeneration:
def test_generate_filename_consistency(self, mock_vllm_config_producer):
"""Test filename generation is consistent."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
......@@ -510,12 +510,12 @@ class TestMetadataBindingLifecycle:
def test_bind_connector_metadata(self, mock_vllm_config_consumer):
"""Test binding connector metadata."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
metadata = ECSharedStorageConnectorMetadata()
metadata = ECExampleConnectorMetadata()
metadata.add_mm_data(MMMeta.make_meta("hash_1", 100))
connector.bind_connector_metadata(metadata)
......@@ -524,12 +524,12 @@ class TestMetadataBindingLifecycle:
def test_clear_connector_metadata(self, mock_vllm_config_consumer):
"""Test clearing connector metadata."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
metadata = ECSharedStorageConnectorMetadata()
metadata = ECExampleConnectorMetadata()
connector.bind_connector_metadata(metadata)
connector.clear_connector_metadata()
......@@ -538,12 +538,12 @@ class TestMetadataBindingLifecycle:
def test_get_connector_metadata(self, mock_vllm_config_consumer):
"""Test getting connector metadata."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
metadata = ECSharedStorageConnectorMetadata()
metadata = ECExampleConnectorMetadata()
connector.bind_connector_metadata(metadata)
retrieved = connector._get_connector_metadata()
......@@ -552,7 +552,7 @@ class TestMetadataBindingLifecycle:
def test_get_connector_metadata_not_set(self, mock_vllm_config_consumer):
"""Test getting metadata when not set raises."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
......@@ -566,7 +566,7 @@ class TestEdgeCases:
def test_save_empty_cache(self, mock_vllm_config_producer):
"""Test saving empty tensor."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.WORKER,
)
......@@ -579,12 +579,12 @@ class TestEdgeCases:
def test_load_nonexistent_cache(self, mock_vllm_config_consumer):
"""Test loading cache that doesn't exist raises error."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_consumer,
role=ECConnectorRole.WORKER,
)
metadata = ECSharedStorageConnectorMetadata()
metadata = ECExampleConnectorMetadata()
metadata.add_mm_data(MMMeta.make_meta("nonexistent_hash", 100))
connector.bind_connector_metadata(metadata)
......@@ -596,7 +596,7 @@ class TestEdgeCases:
def test_has_caches_empty_request(self, mock_vllm_config_producer):
"""Test has_caches with request that has no MM data."""
connector = ECSharedStorageConnector(
connector = ECExampleConnector(
vllm_config=mock_vllm_config_producer,
role=ECConnectorRole.SCHEDULER,
)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test for the fix in PR #29987: Eagerly abort cancelled final-step requests.
This test verifies that when a request is aborted during its final execution
step (when it would naturally complete), it is properly marked as aborted
rather than being treated as normally completed.
The test uses a dummy KV connector to verify that the connector receives
the correct finish status (FINISHED_ABORTED, not FINISHED_LENGTH_CAPPED).
"""
import asyncio
import tempfile
import time
from pathlib import Path
from typing import Any
from unittest.mock import patch
import pytest
from vllm import SamplingParams
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
# Skip the entire module on non-CUDA platforms (see the skip reason below).
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)

# Prompt text submitted to the engine by the test in this module.
TEXT_PROMPT = "Hello"
class DummyKVConnectorMetadata(KVConnectorMetadata):
    """Placeholder metadata for the test connector; holds an empty ``requests`` list."""

    def __init__(self):
        self.requests: list = list()
class DummyKVConnector(KVConnectorBase_V1):
    """
    Dummy KV connector that captures request finish statuses to a file.

    This is used to verify the fix - without the fix, a request aborted
    during its final step would be captured as FINISHED_LENGTH_CAPPED
    instead of FINISHED_ABORTED.

    The connector runs in a separate process, so we write statuses to a file
    that can be read by the test process. All file writes are best-effort:
    failures are reported on stdout and never raised.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: KVCacheConfig | None = None,
    ):
        """
        Initialise the base connector and record an ``INIT:<role>`` marker.

        The status file path is read from the ``status_file`` key of
        ``kv_connector_extra_config``; when it is absent nothing is written.
        """
        super().__init__(vllm_config, role, kv_cache_config)

        # Get the status file path from extra config
        extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config or {}
        self.status_file = extra_config.get("status_file")

        # Log that we were initialized
        if self.status_file:
            try:
                with open(self.status_file, "a") as f:
                    f.write(f"INIT:{role.name}\n")
            except Exception as e:
                # Report but don't fail - this is just test instrumentation.
                # (Mirrors the handling in request_finished instead of
                # silently discarding the error.)
                print(f"[DummyKVConnector] Failed to write init marker: {e}")

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """Always report zero matched tokens and ``False`` as the second flag."""
        return (0, False)

    def update_state_after_alloc(
        self,
        request: Request,
        blocks: Any,
        num_external_tokens: int,
    ):
        """No-op: this connector keeps no allocation state."""
        pass

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """Return a fresh, empty ``DummyKVConnectorMetadata``."""
        return DummyKVConnectorMetadata()

    def request_finished(
        self,
        request: Request,
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Capture the request status when finished by writing to a file.

        Appends ``request.status.name`` as one line to the status file (if
        configured) and always returns ``(False, None)``.
        """
        if self.status_file:
            try:
                with open(self.status_file, "a") as f:
                    # Write the status name (e.g., "FINISHED_ABORTED")
                    f.write(f"{request.status.name}\n")
            except Exception as e:
                # Log but don't fail - this is just test instrumentation
                print(f"[DummyKVConnector] Failed to write status: {e}")
        return False, None

    def start_load_kv(self, forward_context: Any, **kwargs: Any) -> None:
        """No-op: nothing is loaded."""
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        """No-op: there is no load to wait for."""
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: Any,
        attn_metadata: Any,
        **kwargs: Any,
    ) -> None:
        """No-op: nothing is saved."""
        pass

    def wait_for_save(self):
        """No-op: there is no save to wait for."""
        pass
# Register the dummy connector under the name "DummyKVConnector", pointing the
# factory at this module (__name__) and the class name, so the
# KVTransferConfig(kv_connector="DummyKVConnector") used in the test resolves.
KVConnectorFactory.register_connector(
    "DummyKVConnector", __name__, DummyKVConnector.__name__
)
@pytest.mark.parametrize("async_scheduling", [False, True])
@pytest.mark.asyncio
async def test_abort_during_final_step(async_scheduling: bool):
    """
    Test that a request aborted during its final execution step is treated as
    aborted rather than completed.

    This test:
    1. Monkeypatches execute_model to wait for a file to be deleted
    2. Configures a dummy KV connector to capture finish statuses
    3. Starts a request with max_tokens=1 (will complete on first decode step)
    4. Aborts the request, then deletes the file to unblock execute_model
    5. Verifies the KV connector received FINISHED_ABORTED not FINISHED_LENGTH_CAPPED

    See https://github.com/vllm-project/vllm/pull/29987.

    Without the fix, the KV connector would see FINISHED_LENGTH_CAPPED because
    update_from_output() would mark the request as completed before processing
    the abort. This causes KV cache blocks to not be freed properly in
    disaggregated prefill scenarios.

    With the fix, _process_aborts_queue() runs before update_from_output(), so the
    abort takes precedence and the KV connector sees FINISHED_ABORTED.
    """
    # Create three temporary files:
    # 1. ready_file: deleted by execute_model to signal it has started
    # 2. block_file: execute_model waits for this to be deleted
    # 3. status_file: KV connector writes finish statuses here
    with tempfile.NamedTemporaryFile(delete=False) as f:
        ready_file = Path(f.name)
    with tempfile.NamedTemporaryFile(delete=False) as f2:
        block_file = Path(f2.name)
    with tempfile.NamedTemporaryFile(delete=False, mode="w") as f3:
        status_file = Path(f3.name)

    try:
        # Get the original execute_model method
        from vllm.v1.worker.gpu_worker import Worker

        original_execute_model = Worker.execute_model

        def execute_model_with_wait(self, scheduler_output):
            # Signal that execute_model has been called by deleting ready_file.
            # missing_ok avoids a check-then-delete race when the file is
            # already gone (e.g. on a later call).
            ready_file.unlink(missing_ok=True)

            # Wait for the block file to be deleted (triggered from test after abort)
            # This runs in the worker process (after fork), so we poll the filesystem
            while block_file.exists():
                time.sleep(0.01)
            return original_execute_model(self, scheduler_output)

        # Patch execute_model to inject the wait
        # This happens before the worker process is forked, so the patch applies there
        with patch.object(Worker, "execute_model", execute_model_with_wait):
            request_id = "test-abort-final-step"

            # Configure engine with dummy KV connector
            # Pass the status file path so the connector can write to it
            kv_transfer_config = KVTransferConfig(
                kv_connector="DummyKVConnector",
                kv_role="kv_both",
                kv_connector_extra_config={"status_file": str(status_file)},
            )
            engine_args = AsyncEngineArgs(
                model="meta-llama/Llama-3.2-1B-Instruct",
                enforce_eager=True,
                async_scheduling=async_scheduling,
                kv_transfer_config=kv_transfer_config,
            )

            with set_default_torch_num_threads(1):
                engine = AsyncLLM.from_engine_args(engine_args)

            try:
                # Create a request that will complete after just 1 token
                sampling_params = SamplingParams(
                    max_tokens=1,
                    ignore_eos=True,
                    output_kind=RequestOutputKind.DELTA,
                )

                # Start generation in a task
                outputs = []

                async def generate():
                    async for output in engine.generate(
                        request_id=request_id,
                        prompt=TEXT_PROMPT,
                        sampling_params=sampling_params,
                    ):
                        outputs.append(output)

                gen_task = asyncio.create_task(generate())

                # Wait for execute_model to signal it has started (with timeout).
                # A monotonic clock is used so wall-clock adjustments cannot
                # shorten or extend the wait.
                timeout = 5.0  # 5 second timeout
                deadline = time.monotonic() + timeout
                while ready_file.exists():
                    if time.monotonic() > deadline:
                        raise TimeoutError(
                            "Timeout waiting for execute_model to start. "
                            "The monkeypatch may not be working correctly, "
                            "for example if spawn was used instead of fork."
                        )
                    await asyncio.sleep(0.01)

                # Abort the request while execute_model is blocked
                await engine.abort(request_id)

                # Now unblock execute_model by deleting the file
                # The abort should be processed before the model output
                block_file.unlink()

                # Wait for generation to complete
                await gen_task

                # Give the scheduler a moment to finish cleanup
                await asyncio.sleep(0.1)

                # Verify we got output
                assert len(outputs) > 0, "Should have received at least one output"

                # The final output should have finish_reason="abort"
                final_output = outputs[-1]
                assert final_output.finished, (
                    "Final output should be marked as finished"
                )
                assert final_output.outputs[0].finish_reason == "abort", (
                    f"Expected finish_reason='abort' but got "
                    f"'{final_output.outputs[0].finish_reason}'. "
                )

                with open(status_file) as f4:
                    status_lines = f4.read().strip().split("\n")

                # Filter for actual finish statuses (not INIT or empty lines)
                captured_statuses = [
                    line
                    for line in status_lines
                    if line and line.startswith("FINISHED_")
                ]

                assert len(captured_statuses) >= 1, (
                    f"Expected at least 1 captured finish status, got "
                    f"{len(captured_statuses)}. File content: {status_lines}"
                )
                assert "FINISHED_ABORTED" in captured_statuses, (
                    f"KV connector should see FINISHED_ABORTED but got "
                    f"{captured_statuses}. "
                )

                # Verify cleanup
                assert not engine.output_processor.has_unfinished_requests()
            finally:
                # Shutdown the engine
                engine.shutdown()
    finally:
        # Clean up temporary files if they still exist
        for leftover in (ready_file, block_file, status_file):
            leftover.unlink(missing_ok=True)
......@@ -9,6 +9,7 @@ from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.hashing import _xxhash
def test_prefix_caching_from_cli():
......@@ -48,6 +49,21 @@ def test_prefix_caching_from_cli():
args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"])
@pytest.mark.skipif(_xxhash is None, reason="xxhash not installed")
def test_prefix_caching_xxhash_from_cli():
    """Check that both xxhash-based hash algorithms are accepted from the CLI."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    # "xxhash" is the pickle-based variant, "xxhash_cbor" the CBOR-based one.
    for algo in ("xxhash", "xxhash_cbor"):
        cli_args = parser.parse_args(["--prefix-caching-hash-algo", algo])
        engine_config = EngineArgs.from_cli_args(args=cli_args).create_engine_config()
        assert engine_config.cache_config.prefix_caching_hash_algo == algo
def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS)
......
......@@ -507,7 +507,7 @@ def test_encoder_instance_zero_kv_cache(
)
kv_transfer_config = (
KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_connector="ExampleConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
)
......@@ -515,7 +515,7 @@ def test_encoder_instance_zero_kv_cache(
else None
)
ec_transfer_config = ECTransferConfig(
ec_connector="ECSharedStorageConnector",
ec_connector="ECExampleConnector",
ec_role=ec_role,
ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test_encoder"},
)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
import json
import openai # use the official client for correctness check
......@@ -13,6 +11,7 @@ from transformers import AutoConfig
from tests.conftest import ImageTestAssets
from tests.utils import RemoteOpenAIServer
from vllm.utils.serial_utils import tensor2base64
# any model with a chat template should work here
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
......@@ -50,18 +49,6 @@ async def client_with_image_embeds(server_with_image_embeds):
yield async_client
def encode_image_embedding_to_base64(image_embedding) -> str:
    """
    Serialize an image embedding with ``torch.save`` and return the bytes
    as a base64 string decoded to UTF-8 text.
    """
    serialized = io.BytesIO()
    torch.save(image_embedding, serialized)
    # getvalue() returns the whole buffer regardless of the stream position,
    # so no explicit rewind is needed before encoding.
    return base64.b64encode(serialized.getvalue()).decode("utf-8")
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("dtype", [torch.half, torch.float16, torch.float32])
......@@ -73,7 +60,7 @@ async def test_completions_with_image_embeds(
):
# Test case: Single image embeds input
image_embeds = image_assets[0].image_embeds.to(dtype=dtype)
base64_image_embedding = encode_image_embedding_to_base64(image_embeds)
base64_image_embedding = tensor2base64(image_embeds)
chat_completion = await client_with_image_embeds.chat.completions.create(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
......
......@@ -30,7 +30,14 @@ async def lifespan(app: FastAPI):
prefiller_base_url = f"http://{host}:{port}/v1"
app.state.prefill_clients.append(
{
"client": httpx.AsyncClient(timeout=None, base_url=prefiller_base_url),
"client": httpx.AsyncClient(
timeout=None,
base_url=prefiller_base_url,
limits=httpx.Limits(
max_connections=None,
max_keepalive_connections=None,
),
),
"host": host,
"port": port,
"id": i,
......@@ -42,7 +49,14 @@ async def lifespan(app: FastAPI):
decoder_base_url = f"http://{host}:{port}/v1"
app.state.decode_clients.append(
{
"client": httpx.AsyncClient(timeout=None, base_url=decoder_base_url),
"client": httpx.AsyncClient(
timeout=None,
base_url=decoder_base_url,
limits=httpx.Limits(
max_connections=None,
max_keepalive_connections=None,
),
),
"host": host,
"port": port,
"id": i,
......@@ -169,6 +183,10 @@ async def send_request_to_service(
)
response.raise_for_status()
# Read/consume the response body to release the connection;
# otherwise, it would raise httpx.ReadError
await response.aread()
return response
......@@ -206,6 +224,7 @@ async def _handle_completions(api: str, request: Request):
# Extract the needed fields
response_json = response.json()
await response.aclose() # CRITICAL: Release connection back to pool
kv_transfer_params = response_json.get("kv_transfer_params", {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
......
......@@ -218,12 +218,12 @@ def test_internal_connector_uses_new_signature():
Test that internal connectors (registered in factory) always use the new
signature and get kv_cache_config.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
SharedStorageConnector,
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
ExampleConnector,
)
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "SharedStorageConnector"
vllm_config.kv_transfer_config.kv_connector = "ExampleConnector"
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
......@@ -233,7 +233,7 @@ def test_internal_connector_uses_new_signature():
)
assert connector is not None
assert isinstance(connector, SharedStorageConnector)
assert isinstance(connector, ExampleConnector)
assert connector._kv_cache_config is not None
assert connector._kv_cache_config == kv_cache_config
......
......@@ -3,12 +3,14 @@
from dataclasses import asdict
from typing import NamedTuple
import pytest
from PIL import Image
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig
from vllm.multimodal.utils import encode_image_base64
from vllm.platforms import current_platform
MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"
......@@ -108,18 +110,25 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
print("-" * 50)
@pytest.mark.skipif(
current_platform.is_rocm(),
reason=(
"hipErrorLaunchFailure when running this test, see issue:"
"https://github.com/ROCm/pytorch/issues/2822"
),
)
def test_shared_storage_connector_hashes(tmp_path):
"""
Tests that SharedStorageConnector saves KV to the storage locations
Tests that ExampleConnector saves KV to the storage locations
with proper hashes; that are unique for inputs with identical text but
different images (same size), or same multiple images but different orders.
"""
# Using tmp_path as the storage path to store KV
print(f"KV storage path at: {str(tmp_path)}")
# Configure the SharedStorageConnector
# Configure the ExampleConnector
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_connector="ExampleConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": str(tmp_path)},
)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa: E501
SharedStorageConnectorMetadata,
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa: E501
ExampleConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_initialized,
......@@ -11,7 +11,7 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
# Importing utils registers TestSharedStorageConnector with the factory
# Importing utils registers TestExampleConnector with the factory
from .utils import create_vllm_config
......@@ -26,13 +26,13 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
kv_connector_metadata=SharedStorageConnectorMetadata(),
kv_connector_metadata=ExampleConnectorMetadata(),
)
def test_kv_connector_mixin_clears_metadata():
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "TestSharedStorageConnector"
vllm_config.kv_transfer_config.kv_connector = "TestExampleConnector"
vllm_config.kv_transfer_config.kv_role = "kv_both"
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit"
......
......@@ -64,22 +64,6 @@ def test_multimodal_interface():
assumes(PlaceholderRange, "offset")
assumes(PlaceholderRange, "length")
# test a minimal case
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import (
apply_mm_hashes_to_token_ids,
)
token_ids = torch.arange(10, dtype=torch.long)
mm_hashes = ["0000", "1111"] # hex repr of 0 and 4369
mm_positions = [
PlaceholderRange(offset=0, length=4),
PlaceholderRange(offset=5, length=4),
]
apply_mm_hashes_to_token_ids(token_ids, mm_hashes, mm_positions)
assert token_ids.tolist() == [0, 0, 0, 0, 4, 4369, 4369, 4369, 4369, 9]
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
......@@ -122,16 +106,6 @@ def test_config_interface():
assumes(CacheConfig, "block_size")
assumes(CacheConfig, "gpu_memory_utilization")
# mla metadata minimal cases
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import (
mla_enabled,
)
model_config = ModelConfig(model="deepseek-ai/DeepSeek-R1")
assert mla_enabled(model_config)
model_config = ModelConfig(model="Qwen/Qwen3-0.6B")
assert not mla_enabled(model_config)
# kv metadata minimal case
from vllm.utils.torch_utils import get_kv_cache_torch_dtype
......@@ -139,7 +113,7 @@ def test_config_interface():
parallel_config = ParallelConfig()
cache_config = CacheConfig(cache_dtype="bfloat16")
kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, model_config.dtype)
use_mla = mla_enabled(model_config)
use_mla = False
chunk_size = 256
num_layer = model_config.get_num_layers(parallel_config)
num_kv_head = model_config.get_num_kv_heads(parallel_config)
......@@ -184,43 +158,11 @@ def test_request_interface():
assumes(req, "num_tokens")
assumes(req, "kv_transfer_params", is_instance_of=(dict, NoneType))
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
from vllm.multimodal.inputs import MultiModalFeatureSpec
assumes(MultiModalFeatureSpec, "identifier")
assumes(MultiModalFeatureSpec, "mm_position")
# minimal case:
from vllm.multimodal.inputs import PlaceholderRange
request = Request(
request_id="test_request",
prompt_token_ids=[1, 2, 3],
sampling_params=SamplingParams(max_tokens=10),
pooling_params=None,
eos_token_id=100,
lora_request=None,
mm_features=[
MultiModalFeatureSpec(
modality="image",
identifier="0000",
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=PlaceholderRange(offset=0, length=10),
)
],
)
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import (
extract_mm_features,
)
mm_hashes, mm_positions = extract_mm_features(request)
assert isinstance(mm_hashes, list)
assert len(mm_hashes) == 1
assert isinstance(mm_positions, list)
assert len(mm_positions) == 1
assert mm_positions[0].offset == 0
assert mm_positions[0].length == 10
def test_new_request_interface():
# protect against interface changes
......
......@@ -77,9 +77,9 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool:
"https://github.com/ROCm/pytorch/issues/2822"
),
)
def test_multi_shared_storage_connector_consistency():
def test_multi_example_connector_consistency():
"""
Tests that MultiConnector with two SharedStorageConnectors saves
Tests that MultiConnector with two ExampleConnectors saves
identical KV cache data to separate storage locations.
"""
storage_1_path = Path("storage_1/")
......@@ -89,14 +89,14 @@ def test_multi_shared_storage_connector_consistency():
storage_1_path.mkdir()
storage_2_path.mkdir()
# Configure MultiConnector with two SharedStorageConnectors
# Configure MultiConnector with two ExampleConnectors
kv_transfer_config = KVTransferConfig(
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [
{
"kv_connector": "TestSharedStorageConnector",
"kv_connector": "TestExampleConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_1_path),
......@@ -105,7 +105,7 @@ def test_multi_shared_storage_connector_consistency():
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
},
{
"kv_connector": "TestSharedStorageConnector",
"kv_connector": "TestExampleConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_2_path),
......@@ -427,7 +427,7 @@ class TestMultiConnectorStats:
def test_build_kv_connector_stats_skips_connectors_without_custom_stats(self):
"""Test that connectors without custom stats (return None) are skipped."""
# SharedStorageConnector doesn't override build_kv_connector_stats,
# ExampleConnector doesn't override build_kv_connector_stats,
# so it returns None and should be skipped
serialized_data = {
"NixlConnector": {
......@@ -440,7 +440,7 @@ class TestMultiConnectorStats:
"num_failed_notifications": [],
}
},
"SharedStorageConnector": {"data": {"some_field": [1, 2, 3]}},
"ExampleConnector": {"data": {"some_field": [1, 2, 3]}},
}
stats = MultiConnector.build_kv_connector_stats(data=serialized_data)
......@@ -451,8 +451,8 @@ class TestMultiConnectorStats:
assert len(stats.data) == 1
assert "NixlConnector" in stats.data
assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
# SharedStorageConnector should be skipped (returns None)
assert "SharedStorageConnector" not in stats.data
# ExampleConnector should be skipped (returns None)
assert "ExampleConnector" not in stats.data
def test_build_kv_connector_stats_handles_malformed_data(self):
"""Test that malformed data raises appropriate errors."""
......@@ -527,13 +527,13 @@ class TestMultiConnectorStats:
)
stats2 = MultiKVConnectorStats(
data={"SharedStorageConnector": KVConnectorStats(data={"field": [1, 2]})}
data={"ExampleConnector": KVConnectorStats(data={"field": [1, 2]})}
)
result = stats1.aggregate(stats2)
assert "NixlConnector" in result.data
assert "SharedStorageConnector" in result.data
assert "ExampleConnector" in result.data
def test_reduce(self):
"""Test that reduce() correctly reduces all nested connector stats."""
......
......@@ -9,8 +9,10 @@ import textwrap
import time
import uuid
from collections import defaultdict
from unittest.mock import patch
from typing import Any
from unittest.mock import MagicMock, patch
import msgspec
import pytest
import ray
import torch
......@@ -18,6 +20,7 @@ import torch
from vllm import LLM
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats,
......@@ -29,13 +32,16 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorMetadata,
NixlConnectorScheduler,
NixlConnectorWorker,
NixlHandshakePayload,
NixlKVConnectorStats,
compute_nixl_compatibility_hash,
)
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_shutdown,
has_kv_transfer_group,
)
from vllm.forward_context import ForwardContext
from vllm.platforms import current_platform
from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
......@@ -317,13 +323,19 @@ def test_kv_transfer_handshake(dist_init):
}
prefill_connector.register_kv_caches(kv_caches)
# Simulate EngineCore initialization that would
# gather connector metadata from all workers, the scheduler connector
# expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
# where the first key is the dp_rank, the second key is the tp_rank.
metadata = {0: prefill_connector.get_handshake_metadata()}
# Simulate EngineCore initialization that would gather connector
# metadata from all workers
metadata = prefill_connector.get_handshake_metadata()
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes)
# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata(metadata)
scheduler_connector.set_xfer_handshake_metadata({0: metadata})
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
......@@ -362,9 +374,9 @@ def test_kv_transfer_handshake(dist_init):
)
received_metadata = mock_add_remote_agent.call_args.args
assert received_metadata[0] == expected_agent_metadata
assert received_metadata[1] == 0 # remote_tp_rank
assert received_metadata[2] == 1 # remote_tp_size
assert metadata[0] == received_metadata[0]
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector.shutdown()
......@@ -403,7 +415,6 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
device_id=0,
num_blocks=1,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
......@@ -460,6 +471,7 @@ class TestNixlHandshake:
num_xfers + 6,
],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
......@@ -526,6 +538,7 @@ class TestNixlHandshake:
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": "prefill-id",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": prefill_tp_size,
......@@ -581,6 +594,7 @@ class TestNixlHandshake:
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-id-{i}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
......@@ -651,7 +665,6 @@ class TestNixlHandshake:
device_id=0,
num_blocks=1,
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
kv_cache_layout=mismatched_layout,
block_size=worker.block_size,
)
......@@ -706,7 +719,6 @@ class TestNixlHandshake:
num_blocks=1,
# prefill TP=1, decode TP=2, remote block_lens is double to local
block_lens=[i * 2 for i in worker.block_len_per_layer],
attn_backend_name=worker.backend_name,
kv_cache_layout="HND",
block_size=worker.block_size,
)
......@@ -746,6 +758,7 @@ def test_kv_connector_stats(dist_init):
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
......@@ -1099,7 +1112,26 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
llm.llm_engine.engine_core.shutdown()
@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "TRITON_ATTN"])
@pytest.mark.parametrize(
"attn_backend",
[
pytest.param(
"FLASH_ATTN",
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Attention backend FLASH_ATTN is not supported on ROCm",
),
),
pytest.param(
"ROCM_ATTN",
marks=pytest.mark.skipif(
not current_platform.is_rocm(),
reason="Attention backend ROCM_ATTN is only supported on ROCm",
),
),
"TRITON_ATTN",
],
)
def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
......@@ -1121,6 +1153,10 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
backend_cls = FlashAttentionBackend
elif attn_backend == "ROCM_ATTN":
from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend
backend_cls = RocmAttentionBackend
else: # TRITON_ATTN
from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
......@@ -1139,25 +1175,43 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
}
# Store tensor info for validation
expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
if is_blocks_first:
expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4
nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with (
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
) as mock_nixl_wrapper,
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
) as mock_thread,
): # noqa: E501
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
patch(f"{nixl_module}.threading.Event"),
patch(f"{nixl_module}.threading.Thread") as mock_thread,
patch(f"{nixl_module}.get_attn_backend") as mock_get_attn_backend,
):
# Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous
# test run if not mocking.
mock_get_attn_backend.return_value = backend_cls
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
......@@ -1168,6 +1222,9 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
mock_wrapper_instance = mock_nixl_wrapper.return_value
connector.connector_worker.nixl_wrapper = mock_wrapper_instance
# Appease NixlHandshakePayload encoding with some bytes
mock_wrapper_instance.get_agent_metadata.return_value = b"fake_agent_metadata"
# Reassure the shutdown() check that the thread is terminated
mock_thread.return_value.is_alive.return_value = False
......@@ -1177,7 +1234,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
# Verify get_reg_descs was called with caches_data
assert mock_wrapper_instance.get_reg_descs.called
caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0]
assert len(caches_data) == 4
assert len(caches_data) == expected_num_entries
for i, cache_entry in enumerate(caches_data):
base_addr, size, _tp_rank, _ = cache_entry
......@@ -1199,7 +1256,12 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
)
expected_block_len = expected_tensor_size // 2
num_blocks = 2
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
expected_block_len = expected_tensor_size // num_blocks
for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
assert block_len == expected_block_len, (
......@@ -1296,7 +1358,7 @@ def test_shutdown_cleans_up_resources(dist_init):
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
patch.object(nixl_wrapper, "deregister_memory") as mock_dereg,
):
worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
worker._recving_transfers = {"req1": [123]}
worker.src_xfer_side_handle = 456
worker.dst_xfer_side_handles = {"engine1": 789}
worker._remote_agents = {"engine1": {0: "agent1"}}
......@@ -1459,6 +1521,7 @@ def test_handshake_failure_returns_finished(dist_init):
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
......@@ -1508,6 +1571,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
kv_transfer_params={
"remote_block_ids": [10, 11, 12],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
......@@ -1534,3 +1598,194 @@ def test_transfer_setup_failure_returns_finished(dist_init):
# ensure request appears in get_finished
_, done_recving = connector.get_finished(finished_req_ids=set())
assert request_id in done_recving
@pytest.mark.parametrize(
    "mismatch_type,config_overrides,version_override,should_fail,enforce_handshake_compat",
    [
        ("vllm_version", {}, {"vllm_version": "0.6.1"}, True, True),
        ("nixl_connector_version", {}, {"connector_version": 37}, True, True),
        ("model_name", {"model": "facebook/opt-350m"}, {}, True, True),
        ("dtype", {"dtype": "bfloat16"}, {}, True, True),
        ("cache_dtype", {"cache_dtype": "fp8"}, {}, True, True),
        ("num_kv_heads", {"hf_overrides": {"num_key_value_heads": 8}}, {}, True, True),
        (
            "num_hidden_layers",
            {"hf_overrides": {"num_hidden_layers": 24}},
            {},
            True,
            True,
        ),
        ("hidden_size", {"hf_overrides": {"hidden_size": 1536}}, {}, True, True),
        ("block_size", {"block_size": 8}, {}, False, True),
        ("matching_config", {}, {}, False, True),
        ("escape_hatch", {"model": "facebook/opt-350m"}, {}, False, False),
    ],
)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_compatibility_hash_validation(
    dist_init,
    mismatch_type,
    config_overrides,
    version_override,
    should_fail,
    enforce_handshake_compat,
):
    """
    Exercise the compatibility-hash check performed by ``_nixl_handshake``.

    A local (decode) worker is built from a fixed config, a "remote" config is
    built with ``config_overrides`` applied, and the hash computed for the
    remote config is delivered to the worker through a mocked ZMQ socket.

    Parameters:
        mismatch_type: label for the parametrized case (not read by the body)
        config_overrides: kwargs overriding the remote ``create_vllm_config``
        version_override: ``{"vllm_version": ...}`` or
            ``{"connector_version": ...}`` to patch while hashing, or ``{}``
        should_fail: expect a ``RuntimeError`` matching
            "compatibility hash mismatch" when True
        enforce_handshake_compat: value placed in the local worker's
            ``kv_connector_extra_config``
    """
    # Local (decode) side.
    decode_worker = NixlConnector(
        create_vllm_config(
            model="facebook/opt-125m",
            block_size=16,
            kv_connector_extra_config={
                "enforce_handshake_compat": enforce_handshake_compat
            },
        ),
        KVConnectorRole.WORKER,
    ).connector_worker

    # Remote (prefill) side: same defaults, then the per-case overrides.
    remote_kwargs: dict[str, Any] = {"model": "facebook/opt-125m", "block_size": 16}
    remote_kwargs.update(config_overrides)
    remote_vllm_config = create_vllm_config(**remote_kwargs)

    # At most one version patch is active while the remote hash is computed.
    if "vllm_version" in version_override:
        version_patch = patch("vllm.__version__", version_override["vllm_version"])
    elif "connector_version" in version_override:
        version_patch = patch.object(
            nixl_connector,
            "NIXL_CONNECTOR_VERSION",
            version_override["connector_version"],
        )
    else:
        version_patch = contextlib.nullcontext()

    with version_patch:
        remote_hash = compute_nixl_compatibility_hash(
            remote_vllm_config, decode_worker.backend_name
        )

    remote_block_size = config_overrides.get("block_size", 16)
    remote_agent_metadata = NixlAgentMetadata(
        engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
        agent_metadata=FakeNixlWrapper.AGENT_METADATA,
        kv_caches_base_addr=[0],
        device_id=0,
        num_blocks=1,
        block_lens=[4096 * remote_block_size],  # slot_size * block_size
        kv_cache_layout="HND",
        block_size=remote_block_size,
    )
    wire_bytes = msgspec.msgpack.encode(
        NixlHandshakePayload(
            compatibility_hash=remote_hash,
            agent_metadata_bytes=msgspec.msgpack.encode(remote_agent_metadata),
        )
    )

    # The mocked ZMQ socket hands the encoded payload back to the worker.
    mock_socket = MagicMock()
    mock_socket.recv.return_value = wire_bytes

    handshake_kwargs = {
        "host": "localhost",
        "port": 1234,
        "remote_tp_size": 1,
        "expected_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
    }

    # add_remote_agent is stubbed to avoid real NIXL operations; zmq_ctx is
    # patched so the handshake receives the mock socket.
    with (
        patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
        patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx,
    ):
        mock_zmq_ctx.return_value.__enter__.return_value = mock_socket

        if not should_fail:
            result = decode_worker._nixl_handshake(**handshake_kwargs)
            # A successful handshake returns a one-entry agent mapping.
            assert isinstance(result, dict)
            assert len(result) == 1
            return

        with pytest.raises(RuntimeError, match="compatibility hash mismatch"):
            decode_worker._nixl_handshake(**handshake_kwargs)
@pytest.mark.parametrize(
    "error_scenario",
    [
        "handshake_decode_error",
        "handshake_validation_error",
        "metadata_decode_error",
        "metadata_validation_error",
    ],
)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_handshake_decode_errors(dist_init, error_scenario):
    """
    Expect ``_nixl_handshake`` to raise ``RuntimeError`` on malformed payloads.

    Four inputs are covered, two per decoding stage:
    - outer NixlHandshakePayload: bytes that are not msgpack, and msgpack with
      the wrong fields
    - inner NixlAgentMetadata: the same two failure shapes, wrapped in an
      otherwise valid handshake payload carrying the worker's own hash
    """
    decode_worker = NixlConnector(
        create_vllm_config(model="facebook/opt-125m", block_size=16),
        KVConnectorRole.WORKER,
    ).connector_worker

    def _wrap_agent_bytes(agent_bytes: bytes) -> bytes:
        # Valid outer payload (matching hash) around the given inner bytes.
        return msgspec.msgpack.encode(
            NixlHandshakePayload(
                compatibility_hash=decode_worker.compat_hash,
                agent_metadata_bytes=agent_bytes,
            )
        )

    # Builders are lazy so only the selected scenario is constructed.
    payload_builders = {
        "handshake_decode_error": lambda: b"this is not valid msgpack data",
        "handshake_validation_error": lambda: msgspec.msgpack.encode(
            {"wrong_field": "value"}
        ),
        "metadata_decode_error": lambda: _wrap_agent_bytes(
            b"invalid msgpack for metadata"
        ),
        "metadata_validation_error": lambda: _wrap_agent_bytes(
            msgspec.msgpack.encode({"missing": "fields"})
        ),
    }
    build_payload = payload_builders.get(error_scenario)
    if build_payload is None:
        raise AssertionError(f"{error_scenario} not a valid scenario")
    msg_bytes = build_payload()

    mock_socket = MagicMock()
    mock_socket.recv.return_value = msg_bytes

    with (
        patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
        patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx,
    ):
        mock_zmq_ctx.return_value.__enter__.return_value = mock_socket

        with pytest.raises(RuntimeError):
            decode_worker._nixl_handshake(
                host="localhost",
                port=1234,
                remote_tp_size=1,
                expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            )
......@@ -24,8 +24,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector,
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa
ExampleConnector,
)
from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
......@@ -90,13 +90,25 @@ def create_vllm_config(
max_model_len: int = 10000,
enable_chunked_prefill: bool = True,
enable_permute_local_kv: bool = False,
kv_connector_extra_config: dict[str, Any] | None = None,
dtype: str = "float16",
cache_dtype: str = "auto",
hf_overrides: dict[str, Any] | None = None,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
model_config = ModelConfig(
model=model,
trust_remote_code=True,
dtype="float16",
dtype=dtype,
seed=42,
hf_overrides=hf_overrides or {},
)
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
enable_chunked_prefill=enable_chunked_prefill,
is_encoder_decoder=model_config.is_encoder_decoder,
)
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
......@@ -110,13 +122,14 @@ def create_vllm_config(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
cache_dtype=cache_dtype,
enable_prefix_caching=True,
)
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
kv_role="kv_both",
enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config=kv_connector_extra_config or {},
)
return VllmConfig(
scheduler_config=scheduler_config,
......@@ -188,6 +201,7 @@ def create_request(
do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_request_id=f"prefill-{request_id}",
remote_block_ids=list(range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
......@@ -257,10 +271,10 @@ def create_model_runner_output(
)
class TestSharedStorageConnector(SharedStorageConnector):
class TestExampleConnector(ExampleConnector):
def __init__(self, config: VllmConfig, role, kv_cache_config):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role)
self._connector = ExampleConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = (
......@@ -387,7 +401,7 @@ class MockKVConnector(KVConnectorBase_V1):
KVConnectorFactory.register_connector(
"TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__
"TestExampleConnector", __name__, TestExampleConnector.__name__
)
KVConnectorFactory.register_connector(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment