Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
...@@ -16,6 +16,16 @@ from vllm.platforms import current_platform ...@@ -16,6 +16,16 @@ from vllm.platforms import current_platform
MTP_SIMILARITY_RATE = 0.8 MTP_SIMILARITY_RATE = 0.8
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
"""Skip test if available GPUs < tp_size on ROCm."""
if current_platform.is_rocm():
available_gpus = torch.cuda.device_count()
if available_gpus < tp_size:
pytest.skip(
f"Test requires {tp_size} GPUs, but only {available_gpus} available"
)
def get_test_prompts(mm_enabled: bool): def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"] prompt_types = ["repeat", "sentence"]
if mm_enabled: if mm_enabled:
...@@ -280,9 +290,20 @@ def test_speculators_model_integration( ...@@ -280,9 +290,20 @@ def test_speculators_model_integration(
@pytest.mark.parametrize( @pytest.mark.parametrize(
["model_setup", "mm_enabled", "enable_chunked_prefill"], ["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
[ [
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False), (
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
False,
False,
"auto",
),
(
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
False,
False,
"transformers",
),
pytest.param( pytest.param(
( (
"eagle3", "eagle3",
...@@ -292,6 +313,7 @@ def test_speculators_model_integration( ...@@ -292,6 +313,7 @@ def test_speculators_model_integration(
), ),
False, False,
False, False,
"auto",
marks=pytest.mark.skip( marks=pytest.mark.skip(
reason="architecture of its eagle3 is LlamaForCausalLMEagle3" reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
), ),
...@@ -305,6 +327,7 @@ def test_speculators_model_integration( ...@@ -305,6 +327,7 @@ def test_speculators_model_integration(
), ),
False, False,
False, False,
"auto",
marks=pytest.mark.skip( marks=pytest.mark.skip(
reason="Skipping due to its head_dim not being a a multiple of 32" reason="Skipping due to its head_dim not being a a multiple of 32"
), ),
...@@ -318,6 +341,7 @@ def test_speculators_model_integration( ...@@ -318,6 +341,7 @@ def test_speculators_model_integration(
), ),
False, False,
True, True,
"auto",
marks=large_gpu_mark(min_gb=40), marks=large_gpu_mark(min_gb=40),
), # works on 4x H100 ), # works on 4x H100
( (
...@@ -329,6 +353,7 @@ def test_speculators_model_integration( ...@@ -329,6 +353,7 @@ def test_speculators_model_integration(
), ),
False, False,
False, False,
"auto",
), ),
pytest.param( pytest.param(
( (
...@@ -339,6 +364,7 @@ def test_speculators_model_integration( ...@@ -339,6 +364,7 @@ def test_speculators_model_integration(
), ),
False, False,
False, False,
"auto",
marks=large_gpu_mark(min_gb=80), marks=large_gpu_mark(min_gb=80),
), # works on 4x H100 ), # works on 4x H100
pytest.param( pytest.param(
...@@ -350,6 +376,7 @@ def test_speculators_model_integration( ...@@ -350,6 +376,7 @@ def test_speculators_model_integration(
), ),
True, True,
True, True,
"auto",
marks=large_gpu_mark(min_gb=80), marks=large_gpu_mark(min_gb=80),
), # works on 4x H100 ), # works on 4x H100
( (
...@@ -361,10 +388,12 @@ def test_speculators_model_integration( ...@@ -361,10 +388,12 @@ def test_speculators_model_integration(
), ),
False, False,
False, False,
"auto",
), ),
], ],
ids=[ ids=[
"qwen3_eagle3", "qwen3_eagle3",
"qwen3_eagle3-transformers",
"qwen3_vl_eagle3", "qwen3_vl_eagle3",
"qwen2_5_vl_eagle3", "qwen2_5_vl_eagle3",
"llama3_eagle", "llama3_eagle",
...@@ -381,6 +410,7 @@ def test_eagle_correctness( ...@@ -381,6 +410,7 @@ def test_eagle_correctness(
model_setup: tuple[str, str, str, int], model_setup: tuple[str, str, str, int],
mm_enabled: bool, mm_enabled: bool,
enable_chunked_prefill: bool, enable_chunked_prefill: bool,
model_impl: str,
attn_backend: str, attn_backend: str,
): ):
if attn_backend == "TREE_ATTN": if attn_backend == "TREE_ATTN":
...@@ -389,6 +419,17 @@ def test_eagle_correctness( ...@@ -389,6 +419,17 @@ def test_eagle_correctness(
"TREE_ATTN is flaky in the test disable for now until it can be " "TREE_ATTN is flaky in the test disable for now until it can be "
"resolved (see https://github.com/vllm-project/vllm/issues/22922)" "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
) )
if model_impl == "transformers":
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("5.0.0.dev")
if installed < required:
pytest.skip(
"Eagle3 with the Transformers modeling backend requires "
f"transformers>={required}, but got {installed}"
)
# Generate test prompts inside the function instead of using fixture # Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled) test_prompts = get_test_prompts(mm_enabled)
...@@ -424,6 +465,8 @@ def test_eagle_correctness( ...@@ -424,6 +465,8 @@ def test_eagle_correctness(
m.setenv("VLLM_ROCM_USE_AITER", "1") m.setenv("VLLM_ROCM_USE_AITER", "1")
method, model_name, spec_model_name, tp_size = model_setup method, model_name, spec_model_name, tp_size = model_setup
_skip_if_insufficient_gpus_for_tp(tp_size)
max_model_len = 2048 max_model_len = 2048
max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
...@@ -448,6 +491,7 @@ def test_eagle_correctness( ...@@ -448,6 +491,7 @@ def test_eagle_correctness(
max_model_len=max_model_len, max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill, enable_chunked_prefill=enable_chunked_prefill,
model_impl=model_impl,
) )
spec_outputs = spec_llm.chat(test_prompts, sampling_config) spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0 matches = 0
...@@ -493,6 +537,7 @@ def test_mtp_correctness( ...@@ -493,6 +537,7 @@ def test_mtp_correctness(
m.setenv("VLLM_MLA_DISABLE", "1") m.setenv("VLLM_MLA_DISABLE", "1")
method, model_name, tp_size = model_setup method, model_name, tp_size = model_setup
_skip_if_insufficient_gpus_for_tp(tp_size)
ref_llm = LLM( ref_llm = LLM(
model=model_name, model=model_name,
......
...@@ -38,7 +38,7 @@ class MockRequest: ...@@ -38,7 +38,7 @@ class MockRequest:
) )
self.mm_features.append(feature) self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int: def get_num_encoder_embeds(self, input_id: int) -> int:
assert input_id < len(self._token_counts) assert input_id < len(self._token_counts)
return self._token_counts[input_id] return self._token_counts[input_id]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.v1.core.kv_cache_utils import check_enough_kv_cache_memory
from vllm.v1.kv_cache_interface import FullAttentionSpec
def test_kv_cache_oom_no_memory():
from unittest.mock import MagicMock
config = MagicMock()
config.model_config.max_model_len = 2048
spec = {
"layer_0": FullAttentionSpec(
block_size=16,
num_kv_heads=8,
head_size=128,
dtype="float16",
)
}
with pytest.raises(ValueError):
check_enough_kv_cache_memory(config, spec, 0)
def test_kv_cache_oom_insufficient_memory(monkeypatch):
from unittest.mock import MagicMock
config = MagicMock()
config.model_config.max_model_len = 2048
config.cache_config.block_size = 16
config.parallel_config.tensor_parallel_size = 1
config.parallel_config.pipeline_parallel_size = 1
config.parallel_config.decode_context_parallel_size = 1
monkeypatch.setattr(
"vllm.v1.core.kv_cache_utils.max_memory_usage_bytes",
lambda c, s: 100 * 1024**3, # 100 GiB
)
spec = {
"layer_0": FullAttentionSpec(
block_size=16,
num_kv_heads=8,
head_size=128,
dtype="float16",
)
}
with pytest.raises(ValueError):
check_enough_kv_cache_memory(config, spec, 1024**3) # 1 GiB
...@@ -76,6 +76,8 @@ def sample_json_schema(): ...@@ -76,6 +76,8 @@ def sample_json_schema():
}, },
"required": ["name", "age", "skills", "grade", "email", "work_history"], "required": ["name", "age", "skills", "grade", "email", "work_history"],
"additionalProperties": False, "additionalProperties": False,
"minProperties": 1,
"maxProperties": 10,
} }
...@@ -96,6 +98,9 @@ def unsupported_json_schema(): ...@@ -96,6 +98,9 @@ def unsupported_json_schema():
}, },
"required": ["score", "tags"], "required": ["score", "tags"],
"additionalProperties": False, "additionalProperties": False,
"patternProperties": {
"^score$": {"type": "integer"},
},
} }
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
test that invalid blocks are evicted from prefix cache to prevent pollution.
verifies that when sync-loading fails, invalid blocks are removed from the
prefix cache hash table so future requests cannot match and reuse corrupted data.
"""
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def fail_scheduler():
"""scheduler with kv_load_failure_policy='fail'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
return create_scheduler(vllm_config)
def test_invalid_blocks_evicted_prevents_cache_pollution(
fail_scheduler: Scheduler,
):
"""
verify invalid blocks are evicted to prevent future cache hits.
scenario:
1. request 1 loads externally-computed blocks (sync mode)
2. some blocks fail to load and are marked invalid
3. with fail policy, invalid blocks should be evicted from prefix cache
4. request is marked as FINISHED_ERROR
"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
# request 1: will have invalid blocks
request1 = create_request(num_tokens=num_prompt_tokens, request_id=1)
fail_scheduler.add_request(request=request1)
req_num_new_matched_tokens = {
request1.request_id: num_external_computed_tokens,
}
# mock connector indicating sync load
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
# request should be running with sync KV load
assert len(fail_scheduler.running) == 1
assert request1.status == RequestStatus.RUNNING
# get allocated block IDs
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_id = req_block_ids[invalid_block_idx]
invalid_block_ids = {invalid_block_id}
# get the block object to verify eviction later
block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
# cache the blocks to simulate they've been computed and cached
# (in real scenario blocks would be cached after compute)
fail_scheduler.kv_cache_manager.cache_blocks(request1, num_external_computed_tokens)
# verify block has a hash (is cached) before reporting invalid blocks
assert block.block_hash is not None, (
f"block {invalid_block_id} should be cached (have a hash) before "
f"eviction test, but hash is None"
)
# report invalid blocks
model_runner_output = create_model_runner_output(
[request1],
invalid_block_ids=invalid_block_ids,
use_eos=False,
)
fail_scheduler.update_from_output(scheduler_output, model_runner_output)
# verify request finished with error (fail policy)
assert request1.status == RequestStatus.FINISHED_ERROR
# critical assertion: invalid block and all subsequent blocks should be evicted
# all blocks from invalid_block_idx onwards become invalid since they were
# computed based on the failed block
for idx in range(invalid_block_idx, len(req_block_ids)):
block_id = req_block_ids[idx]
block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
assert block_obj.block_hash is None, (
f"block {block_id} at index {idx} should have been evicted "
f"(hash reset to None), but hash is {block_obj.block_hash}. "
f"All blocks from index {invalid_block_idx} onwards should be evicted "
f"since they depend on the invalid block at index {invalid_block_idx}."
)
# verify cache contains exactly the valid blocks (before first affected block)
# and none of the invalid blocks (from first affected block onwards)
# valid blocks: all blocks before invalid_block_idx should be cached
for idx in range(invalid_block_idx):
block_id = req_block_ids[idx]
block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
assert block_obj.block_hash is not None, (
f"valid block {block_id} at index {idx} should still be cached "
f"(have a hash), but hash is None. Only blocks from index "
f"{invalid_block_idx} onwards should be evicted."
)
# invalid blocks: verify they're not in the cached_block_hash_to_block map
cached_blocks = (
fail_scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
)
cached_block_ids = {
b.block_id
for blocks_val in cached_blocks._cache.values()
for b in (
[blocks_val] if not isinstance(blocks_val, dict) else blocks_val.values()
)
}
for idx in range(invalid_block_idx, len(req_block_ids)):
block_id = req_block_ids[idx]
assert block_id not in cached_block_ids, (
f"invalid block {block_id} at index {idx} should not be in cache hash table"
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import FinishReason, Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def fail_scheduler():
"""scheduler with kv_load_failure_policy='fail'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
return create_scheduler(vllm_config)
def test_error_propagation_sync_load(fail_scheduler: Scheduler):
"""test invalid_block_ids with fail policy -> FINISHED_ERROR (sync load)"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
fail_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
assert len(fail_scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert fail_scheduler.connector.get_num_new_matched_tokens.call_count == 1
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_ids = {req_block_ids[invalid_block_idx]}
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
assert request.status == RequestStatus.FINISHED_ERROR
assert request.get_finished_reason() == FinishReason.ERROR
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert len(engine_outputs.outputs) == 1
output = engine_outputs.outputs[0]
assert output.request_id == request.request_id
assert output.finish_reason == FinishReason.ERROR
assert len(fail_scheduler.running) == 0
def test_error_propagation_async_load(fail_scheduler: Scheduler):
"""test invalid_block_ids with fail policy -> FINISHED_ERROR (async load)"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
fail_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
assert len(fail_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
(req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
invalid_block_ids = {req_block_ids[invalid_block_idx]}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving=set(),
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
assert request.status == RequestStatus.FINISHED_ERROR
assert request.get_finished_reason() == FinishReason.ERROR
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert len(engine_outputs.outputs) == 1
output = engine_outputs.outputs[0]
assert output.request_id == request.request_id
assert output.finish_reason == FinishReason.ERROR
assert len(fail_scheduler.waiting) == 0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for correctness in invalid block handling.
These tests verify correct behavior in three scenarios:
1. Sync recompute case: Blocks should not be freed for running requests
that need to recompute invalid blocks
2. Sync fail case: Invalid blocks must be evicted from cache when request fails
3. Async recompute case: Invalid blocks should not be cached after transfer
"""
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import FinishReason, Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def fail_scheduler():
"""scheduler with kv_load_failure_policy='fail'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
return create_scheduler(vllm_config)
@pytest.fixture
def recompute_scheduler():
"""scheduler with kv_load_failure_policy='recompute'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "recompute"
return create_scheduler(vllm_config)
def test_sync_recompute_blocks_not_freed_for_running_requests(
recompute_scheduler: Scheduler,
):
"""
Test sync recompute case - blocks must not be freed for running requests.
When a running request has invalid blocks and retry_policy is 'recompute':
1. Request should remain in RUNNING state
2. num_computed_tokens should be truncated to invalid block boundary
3. Blocks should NOT be freed (request still needs them for recomputation)
4. Request should remain in scheduler.requests and scheduler.running
"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * recompute_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
recompute_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
# mock connector indicating sync load
recompute_scheduler.connector = Mock()
recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
recompute_scheduler.connector.request_finished.return_value = (False, None)
recompute_scheduler.connector.take_events.return_value = ()
scheduler_output = recompute_scheduler.schedule()
# request should be running with sync KV load
assert len(recompute_scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert request.status == RequestStatus.RUNNING
# get the allocated block IDs before invalid blocks are reported
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_ids = {req_block_ids[invalid_block_idx]}
# store original num_computed_tokens for comparison
original_num_computed_tokens = request.num_computed_tokens
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
use_eos=False, # not finished - should continue running
)
outputs = recompute_scheduler.update_from_output(
scheduler_output, model_runner_output
)
# critical assertions for recompute case:
# 1. request should still be RUNNING (not finished, not aborted)
assert request.status == RequestStatus.RUNNING, (
f"Request should remain RUNNING for recompute, got {request.status}"
)
# 2. num_computed_tokens should be truncated to first invalid block
expected_truncated_tokens = invalid_block_idx * recompute_scheduler.block_size
assert request.num_computed_tokens == expected_truncated_tokens, (
f"num_computed_tokens should be truncated to {expected_truncated_tokens}, "
f"got {request.num_computed_tokens}"
)
assert request.num_computed_tokens < original_num_computed_tokens, (
"num_computed_tokens should be reduced after invalid block detection"
)
# 3. no output should be generated (request is still running)
# the request should be skipped in the output loop
assert len(outputs) == 0 or request.request_id not in [
out.request_id for outs in outputs.values() for out in outs.outputs
], "No output should be generated for recompute requests"
# 4. request should still be in running queue
assert request in recompute_scheduler.running, (
"Request should remain in running queue for recomputation"
)
# 5. request should still be in scheduler.requests (not deleted)
assert request.request_id in recompute_scheduler.requests, (
"Request should not be deleted from scheduler.requests"
)
# 6. blocks should NOT be freed - verify blocks are still allocated
try:
allocated_blocks = recompute_scheduler.kv_cache_manager.get_block_ids(
request.request_id
)
assert allocated_blocks is not None
assert len(allocated_blocks[0]) > 0, (
"Blocks should still be allocated for recomputation"
)
except KeyError:
pytest.fail(
"Blocks were freed incorrectly! Running requests need their blocks "
"to recompute invalid portions."
)
# 7. verify request can be rescheduled in next step
scheduler_output_2 = recompute_scheduler.schedule()
# request should appear in the new schedule to recompute invalid blocks
scheduled_req_ids = [
req.request_id for req in scheduler_output_2.scheduled_new_reqs
]
if scheduler_output_2.num_scheduled_tokens:
scheduled_req_ids.extend(scheduler_output_2.num_scheduled_tokens.keys())
assert (
request.request_id in scheduled_req_ids or len(recompute_scheduler.running) > 0
), "Request should be reschedulable for recomputation"
def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler):
"""
Test sync fail case - invalid blocks must be evicted from cache.
When a request fails with policy='fail' and has invalid blocks from sync loading:
1. Request should be finished with FINISHED_ERROR
2. Invalid blocks should be evicted from the KV cache
3. Valid blocks (if shared) should remain in cache
4. Future requests should not reuse the invalid blocks
This test verifies that invalid blocks are properly evicted to prevent
cache corruption and reuse of invalid data.
"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
fail_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
# mock connector indicating sync load
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
# request should be running with sync KV load
assert len(fail_scheduler.running) == 1
assert request.status == RequestStatus.RUNNING
# get allocated block IDs
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_id = req_block_ids[invalid_block_idx]
invalid_block_ids = {invalid_block_id}
# verify the block is in the block pool before we report it as invalid
block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
assert block is not None
# report invalid blocks - request should fail
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
# verify request is finished with error
assert request.status == RequestStatus.FINISHED_ERROR
assert request.get_finished_reason() == FinishReason.ERROR
# verify output is generated
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert len(engine_outputs.outputs) == 1
output = engine_outputs.outputs[0]
assert output.request_id == request.request_id
assert output.finish_reason == FinishReason.ERROR
# verify the request was removed from scheduler
assert request.request_id not in fail_scheduler.requests
assert len(fail_scheduler.running) == 0
# critical: verify invalid block was actually freed from cache
# this is the key assertion - the invalid block should no longer be
# tracked by the KV cache manager for this request
# if it's still there, a future request could reuse the invalid data
try:
block_ids = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
# if we get here, check if blocks were actually freed
if block_ids is not None and len(block_ids[0]) > 0:
pytest.fail(
f"Invalid blocks still tracked for finished request! "
f"Request {request.request_id} should have been freed but "
f"still has {len(block_ids[0])} blocks allocated."
)
# blocks list exists but is empty - this is fine, they were freed
except KeyError:
# expected - request completely removed from tracking
pass
# critical: verify invalid block was evicted from prefix cache
# the block should no longer have a hash (hash is reset on eviction)
assert block.block_hash is None, (
f"Invalid block {invalid_block_id} should have been evicted from cache "
f"(hash should be None), but hash is still {block.block_hash}"
)
def test_async_recompute_blocks_not_cached_when_invalid(
recompute_scheduler: Scheduler,
):
"""
Test async recompute case - invalid blocks not cached after transfer.
When async KV loading has invalid blocks and retry_policy is 'recompute':
1. Blocks are allocated but not cached yet
2. When async transfer completes, only valid blocks should be cached
3. Invalid blocks should never enter the prefix cache
This test verifies correctness, the failed_recving_kv_req_ids protection
ensures only valid blocks are cached when the transfer completes, and we
only evict blocks from cache that are already hashed in the block table.
"""
from unittest.mock import patch
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * recompute_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
recompute_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
# mock connector indicating async load
recompute_scheduler.connector = Mock()
recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
)
recompute_scheduler.connector.request_finished.return_value = (False, None)
recompute_scheduler.connector.take_events.return_value = ()
scheduler_output = recompute_scheduler.schedule()
# request should be waiting for remote KVs
assert len(recompute_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
# get the allocated block IDs
(req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids(
request.request_id
)
invalid_block_id = req_block_ids[invalid_block_idx]
invalid_block_ids = {invalid_block_id}
# get the block object to verify it's not cached yet and stays uncached
block = recompute_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
# verify block has no hash before invalid blocks are reported
assert block.block_hash is None, (
"Async loading blocks should not be cached yet (no hash)"
)
# report invalid blocks (transfer not finished yet)
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving=None, # transfer NOT finished
invalid_block_ids=invalid_block_ids,
use_eos=False,
)
# critical: spy on evict_blocks to verify it's NOT called for async blocks
original_evict_blocks = recompute_scheduler.kv_cache_manager.evict_blocks
evict_blocks_calls = []
def evict_blocks_spy(block_ids):
evict_blocks_calls.append(set(block_ids))
return original_evict_blocks(block_ids)
with patch.object(
recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy
):
recompute_scheduler.update_from_output(scheduler_output, model_runner_output)
# verify evict_blocks was NOT called (async blocks excluded from eviction)
assert len(evict_blocks_calls) == 0, (
f"evict_blocks should not be called for async-only invalid blocks, "
f"but was called {len(evict_blocks_calls)} time(s) with {evict_blocks_calls}"
)
# request should still be waiting (not finished with error due to recompute policy)
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids
# verify num_computed_tokens was truncated to before invalid block
expected_valid_tokens = invalid_block_idx * recompute_scheduler.block_size
assert request.num_computed_tokens == expected_valid_tokens
# verify invalid block still has no hash (was not evicted)
assert block.block_hash is None, (
f"Async loading blocks shouldn't be cached or evicted. "
f"Block {invalid_block_id} hash should be None but is {block.block_hash}"
)
# now simulate async transfer completing
model_runner_output_2 = create_model_runner_output(
reqs=[],
finished_recving={request.request_id},
invalid_block_ids=None,
use_eos=False,
)
recompute_scheduler.update_from_output(scheduler_output, model_runner_output_2)
# verify request is now marked as finished receiving and ready to be processed
assert request.request_id in recompute_scheduler.finished_recving_kv_req_ids
assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids
# critical: verify invalid block still has no hash before recompute
# the async transfer invalid data was never cached
assert block.block_hash is None, (
f"Invalid block {invalid_block_id} should not be cached before recompute "
f"(hash should be None), but hash is {block.block_hash}"
)
# critical end-to-end test: spy on cache_blocks to verify it's called with
# the truncated num_computed_tokens value
original_cache_blocks = recompute_scheduler.kv_cache_manager.cache_blocks
cache_blocks_calls = []
def cache_blocks_spy(req, num_tokens):
cache_blocks_calls.append((req.request_id, num_tokens))
return original_cache_blocks(req, num_tokens)
with patch.object(
recompute_scheduler.kv_cache_manager, "cache_blocks", cache_blocks_spy
):
# call schedule() again - this triggers _update_waiting_for_remote_kv()
# which should call cache_blocks with the truncated value
recompute_scheduler.schedule()
# verify cache_blocks was called with the truncated value
assert len(cache_blocks_calls) == 1, (
f"cache_blocks should be called exactly once, "
f"got {len(cache_blocks_calls)} calls"
)
cached_req_id, cached_num_tokens = cache_blocks_calls[0]
assert cached_req_id == request.request_id
assert cached_num_tokens == expected_valid_tokens, (
f"cache_blocks should be called with truncated value {expected_valid_tokens}, "
f"but was called with {cached_num_tokens}"
)
# request should now be RUNNING (scheduled immediately after transfer completes)
# the flow is: WAITING_FOR_REMOTE_KVS -> WAITING -> RUNNING in same schedule() call
assert request.status == RequestStatus.RUNNING
# num_computed_tokens should be >= expected_valid_tokens because the scheduler
# will schedule additional new tokens (up to max_num_batched_tokens) for the request
assert request.num_computed_tokens >= expected_valid_tokens, (
f"num_computed_tokens should be at least {expected_valid_tokens}, "
f"got {request.num_computed_tokens}"
)
# request should no longer be in the failed/finished receiving sets
assert request.request_id not in recompute_scheduler.failed_recving_kv_req_ids
assert request.request_id not in recompute_scheduler.finished_recving_kv_req_ids
# request should be in the running queue
assert request in recompute_scheduler.running
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from vllm.distributed.kv_events import BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
LMCacheConnectorV1,
LMCacheKVEvents,
)
from vllm.v1.outputs import KVConnectorOutput
@pytest.fixture
def mock_lmcache_engine_event():
"""Create a mock event object that mimics what the lmcache engine returns."""
class MockEvent:
def __init__(
self,
block_hashes,
parent_block_hash,
token_ids,
lora_id,
block_size,
medium,
):
self.block_hashes = block_hashes
self.parent_block_hash = parent_block_hash
self.token_ids = token_ids
self.lora_id = lora_id
self.block_size = block_size
self.medium = medium
return MockEvent(
block_hashes=["hash1", "hash2"],
parent_block_hash="parent_hash",
token_ids=[1, 2, 3, 4],
lora_id=None,
block_size=16,
medium="GPU",
)
@pytest.fixture
def mock_connector():
"""Create a mock LMCacheConnectorV1 instance with mocked dependencies."""
connector = MagicMock(spec=LMCacheConnectorV1)
connector._kv_cache_events = None
connector._lmcache_engine = MagicMock()
# Make the methods use the real implementation
connector.get_kv_connector_kv_cache_events = (
LMCacheConnectorV1.get_kv_connector_kv_cache_events.__get__(
connector, LMCacheConnectorV1
)
)
connector.update_connector_output = (
LMCacheConnectorV1.update_connector_output.__get__(
connector, LMCacheConnectorV1
)
)
connector.take_events = LMCacheConnectorV1.take_events.__get__(
connector, LMCacheConnectorV1
)
return connector
class TestGetKVConnectorKVCacheEvents:
"""Test get_kv_connector_kv_cache_events method."""
def test_returns_none_when_no_events(self, mock_connector):
"""Test that None is returned when lmcache engine has no events."""
mock_connector._lmcache_engine.get_kv_events.return_value = None
result = mock_connector.get_kv_connector_kv_cache_events()
assert result is None
mock_connector._lmcache_engine.get_kv_events.assert_called_once()
def test_returns_none_when_empty_list(self, mock_connector):
"""Test that None is returned when lmcache engine returns empty list."""
mock_connector._lmcache_engine.get_kv_events.return_value = []
result = mock_connector.get_kv_connector_kv_cache_events()
assert result is None
def test_converts_single_event(self, mock_connector, mock_lmcache_engine_event):
"""Test conversion of a single event from lmcache engine format."""
mock_connector._lmcache_engine.get_kv_events.return_value = [
mock_lmcache_engine_event
]
result = mock_connector.get_kv_connector_kv_cache_events()
assert result is not None
assert isinstance(result, LMCacheKVEvents)
assert result.get_number_of_workers() == 1
events = result.get_all_events()
assert len(events) == 1
assert isinstance(events[0], BlockStored)
assert events[0].block_hashes == ["hash1", "hash2"]
assert events[0].parent_block_hash == "parent_hash"
assert events[0].token_ids == [1, 2, 3, 4]
assert events[0].lora_id is None
assert events[0].block_size == 16
assert events[0].medium == "GPU"
def test_converts_multiple_events(self, mock_connector):
"""Test conversion of multiple events from lmcache engine format."""
class MockEvent:
def __init__(self, i):
self.block_hashes = [f"hash{i}"]
self.parent_block_hash = f"parent{i}"
self.token_ids = [i]
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
events = [MockEvent(i) for i in range(5)]
mock_connector._lmcache_engine.get_kv_events.return_value = events
result = mock_connector.get_kv_connector_kv_cache_events()
assert result is not None
assert isinstance(result, LMCacheKVEvents)
converted_events = result.get_all_events()
assert len(converted_events) == 5
for i, event in enumerate(converted_events):
assert isinstance(event, BlockStored)
assert event.block_hashes == [f"hash{i}"]
assert event.parent_block_hash == f"parent{i}"
assert event.token_ids == [i]
def test_preserves_event_attributes(self, mock_connector):
"""Test that all event attributes are correctly preserved."""
class MockEventWithLora:
def __init__(self):
self.block_hashes = ["hash_a", "hash_b", "hash_c"]
self.parent_block_hash = "parent_xyz"
self.token_ids = [100, 200, 300]
self.lora_id = 42
self.block_size = 32
self.medium = "DISK"
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEventWithLora()
]
result = mock_connector.get_kv_connector_kv_cache_events()
events = result.get_all_events()
event = events[0]
assert event.block_hashes == ["hash_a", "hash_b", "hash_c"]
assert event.parent_block_hash == "parent_xyz"
assert event.token_ids == [100, 200, 300]
assert event.lora_id == 42
assert event.block_size == 32
assert event.medium == "DISK"
def test_handles_none_parent_block_hash(self, mock_connector):
"""Test handling of events with None parent_block_hash."""
class MockEventNoParent:
def __init__(self):
self.block_hashes = ["hash1"]
self.parent_block_hash = None
self.token_ids = [1, 2]
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEventNoParent()
]
result = mock_connector.get_kv_connector_kv_cache_events()
events = result.get_all_events()
assert events[0].parent_block_hash is None
class TestUpdateConnectorOutput:
"""Test update_connector_output method."""
def test_does_nothing_when_kv_cache_events_is_none(self, mock_connector):
"""Test that method returns early when kv_cache_events is None."""
connector_output = KVConnectorOutput(kv_cache_events=None)
mock_connector.update_connector_output(connector_output)
assert mock_connector._kv_cache_events is None
def test_does_nothing_when_kv_cache_events_is_not_lmcache_kv_events(
self, mock_connector
):
"""Test that method returns early when kv_cache_events is not
LMCacheKVEvents."""
# Create a mock object that is not LMCacheKVEvents
fake_events = MagicMock()
connector_output = KVConnectorOutput(kv_cache_events=fake_events)
mock_connector.update_connector_output(connector_output)
assert mock_connector._kv_cache_events is None
def test_sets_kv_cache_events_when_none(self, mock_connector):
"""Test that _kv_cache_events is set when it was None."""
kv_events = LMCacheKVEvents(num_workers=1)
event = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1, 2],
block_size=16,
lora_id=None,
medium="GPU",
)
kv_events.add_events([event])
connector_output = KVConnectorOutput(kv_cache_events=kv_events)
mock_connector.update_connector_output(connector_output)
assert mock_connector._kv_cache_events is kv_events
def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector):
"""Test that events are added when _kv_cache_events already exists."""
# Set up existing events
existing_events = LMCacheKVEvents(num_workers=2)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
existing_events.add_events([event1])
existing_events.add_events([event1]) # Simulate 2 workers reporting
mock_connector._kv_cache_events = existing_events
# Create new events to add
new_events = LMCacheKVEvents(num_workers=1)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
new_events.add_events([event2])
connector_output = KVConnectorOutput(kv_cache_events=new_events)
mock_connector.update_connector_output(connector_output)
# Check that events were added
all_events = mock_connector._kv_cache_events.get_all_events()
assert len(all_events) == 3 # 2 from existing + 1 from new
assert event1 in all_events
assert event2 in all_events
def test_increments_workers_when_kv_cache_events_already_exists(
self, mock_connector
):
"""Test that worker count is incremented correctly."""
# Set up existing events with 2 workers
existing_events = LMCacheKVEvents(num_workers=2)
mock_connector._kv_cache_events = existing_events
# Create new events from 3 workers
new_events = LMCacheKVEvents(num_workers=3)
event = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
new_events.add_events([event])
connector_output = KVConnectorOutput(kv_cache_events=new_events)
mock_connector.update_connector_output(connector_output)
# Worker count should be 2 + 3 = 5
assert mock_connector._kv_cache_events.get_number_of_workers() == 5
def test_multiple_updates(self, mock_connector):
"""Test multiple consecutive updates."""
# First update
events1 = LMCacheKVEvents(num_workers=1)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
events1.add_events([event1])
output1 = KVConnectorOutput(kv_cache_events=events1)
mock_connector.update_connector_output(output1)
# Second update
events2 = LMCacheKVEvents(num_workers=2)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
events2.add_events([event2])
output2 = KVConnectorOutput(kv_cache_events=events2)
mock_connector.update_connector_output(output2)
# Third update
events3 = LMCacheKVEvents(num_workers=1)
event3 = BlockStored(
block_hashes=["hash3"],
parent_block_hash=None,
token_ids=[3],
block_size=16,
lora_id=None,
medium="GPU",
)
events3.add_events([event3])
output3 = KVConnectorOutput(kv_cache_events=events3)
mock_connector.update_connector_output(output3)
# Check final state
all_events = mock_connector._kv_cache_events.get_all_events()
assert len(all_events) == 3
assert mock_connector._kv_cache_events.get_number_of_workers() == 4 # 1+2+1
def test_updates_with_empty_events(self, mock_connector):
"""Test updating with empty event lists."""
# First update with actual events
events1 = LMCacheKVEvents(num_workers=1)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
events1.add_events([event1])
output1 = KVConnectorOutput(kv_cache_events=events1)
mock_connector.update_connector_output(output1)
# Second update with empty events
events2 = LMCacheKVEvents(num_workers=2)
# No events added
output2 = KVConnectorOutput(kv_cache_events=events2)
mock_connector.update_connector_output(output2)
# Should still have the original event
all_events = mock_connector._kv_cache_events.get_all_events()
assert len(all_events) == 1
assert mock_connector._kv_cache_events.get_number_of_workers() == 3
class TestTakeEvents:
"""Test take_events method."""
def test_yields_nothing_when_kv_cache_events_is_none(self, mock_connector):
"""Test that nothing is yielded when _kv_cache_events is None."""
mock_connector._kv_cache_events = None
events = list(mock_connector.take_events())
assert events == []
def test_yields_events_and_clears(self, mock_connector):
"""Test that events are yielded and then cleared."""
# Set up events
kv_events = LMCacheKVEvents(num_workers=1)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
kv_events.add_events([event1, event2])
mock_connector._kv_cache_events = kv_events
# Take events
events = list(mock_connector.take_events())
# Check that events were yielded
assert len(events) == 2
assert event1 in events
assert event2 in events
# Check that _kv_cache_events was cleared
assert mock_connector._kv_cache_events is None
def test_aggregates_before_yielding(self, mock_connector):
"""Test that events are aggregated before yielding."""
# Set up events from multiple workers
kv_events = LMCacheKVEvents(num_workers=3)
common_event = BlockStored(
block_hashes=["hash_common"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
uncommon_event = BlockStored(
block_hashes=["hash_uncommon"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
# All 3 workers report common_event
kv_events.add_events([common_event])
kv_events.add_events([common_event])
kv_events.add_events([common_event])
# Only 1 worker reports uncommon_event
kv_events.add_events([uncommon_event])
mock_connector._kv_cache_events = kv_events
# Take events
events = list(mock_connector.take_events())
# Only the common event should be yielded
assert len(events) == 1
assert events[0] == common_event
def test_multiple_take_events_calls(self, mock_connector):
"""Test calling take_events multiple times."""
# First call with events
kv_events1 = LMCacheKVEvents(num_workers=1)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
kv_events1.add_events([event1])
mock_connector._kv_cache_events = kv_events1
events1 = list(mock_connector.take_events())
assert len(events1) == 1
assert events1[0] == event1
assert mock_connector._kv_cache_events is None
# Second call with no events
events2 = list(mock_connector.take_events())
assert events2 == []
# Third call after adding new events
kv_events2 = LMCacheKVEvents(num_workers=1)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
kv_events2.add_events([event2])
mock_connector._kv_cache_events = kv_events2
events3 = list(mock_connector.take_events())
assert len(events3) == 1
assert events3[0] == event2
def test_yields_empty_after_aggregation_removes_all(self, mock_connector):
"""Test that nothing is yielded if aggregation removes all events."""
# Set up events from 2 workers with no common events
kv_events = LMCacheKVEvents(num_workers=2)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
# Worker 1 reports event1
kv_events.add_events([event1])
# Worker 2 reports event2
kv_events.add_events([event2])
mock_connector._kv_cache_events = kv_events
# Take events
events = list(mock_connector.take_events())
# No common events, so nothing should be yielded
assert events == []
assert mock_connector._kv_cache_events is None
class TestIntegrationScenarios:
"""Test integration scenarios."""
def test_full_workflow(self, mock_connector, mock_lmcache_engine_event):
"""Test a complete workflow from getting events to taking them."""
# Step 1: Get events from lmcache engine
mock_connector._lmcache_engine.get_kv_events.return_value = [
mock_lmcache_engine_event
]
kv_events = mock_connector.get_kv_connector_kv_cache_events()
assert kv_events is not None
assert len(kv_events.get_all_events()) == 1
# Step 2: Update connector output (simulate receiving from worker)
output1 = KVConnectorOutput(kv_cache_events=kv_events)
mock_connector.update_connector_output(output1)
assert mock_connector._kv_cache_events is not None
# Step 3: Take events
taken_events = list(mock_connector.take_events())
assert len(taken_events) == 1
assert mock_connector._kv_cache_events is None
def test_multiple_workers_workflow(self, mock_connector):
"""Test workflow with multiple workers."""
class MockEvent:
def __init__(self, hash_val):
self.block_hashes = [hash_val]
self.parent_block_hash = None
self.token_ids = [1]
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
# Worker 1
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEvent("hash_common"),
MockEvent("hash_worker1"),
]
kv_events1 = mock_connector.get_kv_connector_kv_cache_events()
output1 = KVConnectorOutput(kv_cache_events=kv_events1)
mock_connector.update_connector_output(output1)
# Worker 2
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEvent("hash_common"),
MockEvent("hash_worker2"),
]
kv_events2 = mock_connector.get_kv_connector_kv_cache_events()
output2 = KVConnectorOutput(kv_cache_events=kv_events2)
mock_connector.update_connector_output(output2)
# Take events (should only get common events)
taken_events = list(mock_connector.take_events())
# With aggregation, only events reported by both workers should be present
# In this case, hash_common was reported by both
event_hashes = [e.block_hashes[0] for e in taken_events]
assert "hash_common" in event_hashes
def test_empty_workflow(self, mock_connector):
"""Test workflow when there are no events at any stage."""
# Get events returns None
mock_connector._lmcache_engine.get_kv_events.return_value = None
kv_events = mock_connector.get_kv_connector_kv_cache_events()
assert kv_events is None
# Update with None
output = KVConnectorOutput(kv_cache_events=None)
mock_connector.update_connector_output(output)
# Take events
taken_events = list(mock_connector.take_events())
assert taken_events == []
assert mock_connector._kv_cache_events is None
def test_repeated_cycles(self, mock_connector):
"""Test multiple cycles of the complete workflow."""
class MockEvent:
def __init__(self, cycle_num):
self.block_hashes = [f"hash_cycle_{cycle_num}"]
self.parent_block_hash = None
self.token_ids = [cycle_num]
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
for cycle in range(3):
# Get events
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEvent(cycle)
]
kv_events = mock_connector.get_kv_connector_kv_cache_events()
# Update
output = KVConnectorOutput(kv_cache_events=kv_events)
mock_connector.update_connector_output(output)
# Take
taken_events = list(mock_connector.take_events())
# Verify
assert len(taken_events) == 1
assert taken_events[0].block_hashes[0] == f"hash_cycle_{cycle}"
assert mock_connector._kv_cache_events is None
def test_lmcache_kv_events_aggregation(self):
"""
Test LMCacheKVEvents aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.v1.outputs import ModelRunnerOutput
# Create KVOutputAggregator for 3 workers (simulating TP=3)
aggregator = KVOutputAggregator(expected_finished_count=3)
# Define common and unique events
common_event = BlockStored(
block_hashes=["hash_common"],
parent_block_hash="parent_common",
token_ids=[1, 2, 3],
block_size=16,
lora_id=None,
medium="GPU",
)
worker1_unique_event = BlockStored(
block_hashes=["hash_worker1"],
parent_block_hash="parent_w1",
token_ids=[4, 5],
block_size=16,
lora_id=None,
medium="GPU",
)
worker2_unique_event = BlockStored(
block_hashes=["hash_worker2"],
parent_block_hash="parent_w2",
token_ids=[6, 7],
block_size=16,
lora_id=None,
medium="GPU",
)
worker3_unique_event = BlockStored(
block_hashes=["hash_worker3"],
parent_block_hash="parent_w3",
token_ids=[8, 9],
block_size=16,
lora_id=None,
medium="GPU",
)
# Create events for each worker
# Worker 0: reports common event and its unique event
worker0_events = LMCacheKVEvents(num_workers=1)
worker0_events.add_events([common_event, worker1_unique_event])
# Worker 1: reports common event and its unique event
worker1_events = LMCacheKVEvents(num_workers=1)
worker1_events.add_events([common_event, worker2_unique_event])
# Worker 2: reports common event and its unique event
worker2_events = LMCacheKVEvents(num_workers=1)
worker2_events.add_events([common_event, worker3_unique_event])
# Create ModelRunnerOutput instances for each worker
worker_outputs = []
for i, worker_events in enumerate(
[worker0_events, worker1_events, worker2_events]
):
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
sampled_token_ids=[[123]], # dummy token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=set([f"req_{i}_send"])
if i < 2
else None, # Workers 0,1 finished sending
finished_recving=set([f"req_{i}_recv"])
if i > 0
else None, # Workers 1,2 finished receiving
kv_cache_events=worker_events,
),
)
worker_outputs.append(output)
# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
kv_cache_events = aggregated_output.kv_connector_output.kv_cache_events
assert isinstance(kv_cache_events, LMCacheKVEvents)
# After aggregation, events should be combined from all workers
# The aggregator doesn't automatically aggregate events, so we need to call
# aggregate() to get only common events
kv_cache_events.aggregate()
aggregated_events = kv_cache_events.get_all_events()
# Only the common event should remain after aggregation
# because it's the only event reported by all 3 workers
assert len(aggregated_events) == 1
assert aggregated_events[0] == common_event
# Verify the common event properties
assert aggregated_events[0].block_hashes == ["hash_common"]
assert aggregated_events[0].parent_block_hash == "parent_common"
assert aggregated_events[0].token_ids == [1, 2, 3]
...@@ -461,7 +461,7 @@ class TestNixlHandshake: ...@@ -461,7 +461,7 @@ class TestNixlHandshake:
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
if num_xfers > 0: if num_xfers > 0:
num_xfers -= 1 num_xfers -= 1
metadata.add_new_req( metadata.add_new_req_to_recv(
request_id=request_id, request_id=request_id,
local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3], local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3],
kv_transfer_params={ kv_transfer_params={
...@@ -532,7 +532,7 @@ class TestNixlHandshake: ...@@ -532,7 +532,7 @@ class TestNixlHandshake:
vllm_config, connector.engine_id vllm_config, connector.engine_id
) )
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
metadata.add_new_req( metadata.add_new_req_to_recv(
request_id="id", request_id="id",
local_block_ids=[1, 2, 3], local_block_ids=[1, 2, 3],
kv_transfer_params={ kv_transfer_params={
...@@ -588,7 +588,7 @@ class TestNixlHandshake: ...@@ -588,7 +588,7 @@ class TestNixlHandshake:
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
total_reqs = 5 total_reqs = 5
for i in range(total_reqs): for i in range(total_reqs):
metadata.add_new_req( metadata.add_new_req_to_recv(
request_id=f"id_{i}", request_id=f"id_{i}",
local_block_ids=[1, 2, 3], local_block_ids=[1, 2, 3],
kv_transfer_params={ kv_transfer_params={
...@@ -752,7 +752,7 @@ def test_kv_connector_stats(dist_init): ...@@ -752,7 +752,7 @@ def test_kv_connector_stats(dist_init):
# Create transfer metadata # Create transfer metadata
request_id = "test_req_for_stats" request_id = "test_req_for_stats"
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
metadata.add_new_req( metadata.add_new_req_to_recv(
request_id=request_id, request_id=request_id,
local_block_ids=[1, 2, 3], local_block_ids=[1, 2, 3],
kv_transfer_params={ kv_transfer_params={
...@@ -1515,7 +1515,7 @@ def test_handshake_failure_returns_finished(dist_init): ...@@ -1515,7 +1515,7 @@ def test_handshake_failure_returns_finished(dist_init):
request_id = "test_handshake_fail" request_id = "test_handshake_fail"
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
metadata.add_new_req( metadata.add_new_req_to_recv(
request_id=request_id, request_id=request_id,
local_block_ids=[1, 2, 3], local_block_ids=[1, 2, 3],
kv_transfer_params={ kv_transfer_params={
...@@ -1565,7 +1565,7 @@ def test_transfer_setup_failure_returns_finished(dist_init): ...@@ -1565,7 +1565,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
request_id = "test_transfer_fail" request_id = "test_transfer_fail"
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
metadata.add_new_req( metadata.add_new_req_to_recv(
request_id=request_id, request_id=request_id,
local_block_ids=[7, 8, 9], local_block_ids=[7, 8, 9],
kv_transfer_params={ kv_transfer_params={
......
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers
BACKENDS_TO_TEST = [FlashAttentionBackend] BACKENDS_TO_TEST = [FlashAttentionBackend]
...@@ -82,7 +82,7 @@ def test_transfer( ...@@ -82,7 +82,7 @@ def test_transfer(
# create handler # create handler
cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
handler = CpuGpuOffloadingHandler( handlers = CpuGpuOffloadingHandlers(
attn_backends=attn_backends, attn_backends=attn_backends,
gpu_block_size=gpu_block_size, gpu_block_size=gpu_block_size,
cpu_block_size=cpu_block_size, cpu_block_size=cpu_block_size,
...@@ -112,8 +112,7 @@ def test_transfer( ...@@ -112,8 +112,7 @@ def test_transfer(
# set transfer direction # set transfer direction
if gpu_to_cpu: if gpu_to_cpu:
src_kv_caches = handler.gpu_tensors handler = handlers.gpu_to_cpu_handler
dst_kv_caches = handler.cpu_tensors
src_spec_class = GPULoadStoreSpec src_spec_class = GPULoadStoreSpec
dst_spec_class = CPULoadStoreSpec dst_spec_class = CPULoadStoreSpec
src_blocks = gpu_blocks src_blocks = gpu_blocks
...@@ -122,8 +121,7 @@ def test_transfer( ...@@ -122,8 +121,7 @@ def test_transfer(
dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
else: else:
src_kv_caches = handler.cpu_tensors handler = handlers.cpu_to_gpu_handler
dst_kv_caches = handler.gpu_tensors
src_spec_class = CPULoadStoreSpec src_spec_class = CPULoadStoreSpec
dst_spec_class = GPULoadStoreSpec dst_spec_class = GPULoadStoreSpec
src_blocks = cpu_blocks src_blocks = cpu_blocks
...@@ -144,12 +142,12 @@ def test_transfer( ...@@ -144,12 +142,12 @@ def test_transfer(
dst_spec = dst_spec_class(dst_blocks) dst_spec = dst_spec_class(dst_blocks)
# clone src and dst tensors before transfer # clone src and dst tensors before transfer
orig_src_caches = [x.clone() for x in src_kv_caches] orig_src_caches = [x.clone() for x in handler.src_tensors]
orig_dst_caches = [x.clone() for x in dst_kv_caches] orig_dst_caches = [x.clone() for x in handler.dst_tensors]
# call transfer function # call transfer function
assert handler.transfer_async(1, (src_spec, dst_spec)) assert handler.transfer_async(1, (src_spec, dst_spec))
assert set(handler.transfer_events.keys()) == {1} assert set({x[0] for x in handler._transfers}) == {1}
# wait for transfer to complete # wait for transfer to complete
end_time = time.time() + 10 end_time = time.time() + 10
...@@ -161,15 +159,15 @@ def test_transfer( ...@@ -161,15 +159,15 @@ def test_transfer(
time.sleep(0.1) time.sleep(0.1)
# verify src tensors did not change # verify src tensors did not change
for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches): for orig_tensor, tensor in zip(orig_src_caches, handler.src_tensors):
assert torch.equal(orig_tensor, tensor) assert torch.equal(orig_tensor, tensor)
# verify dst tensors # verify dst tensors
for dst_block in range(dst_size_in_gpu_blocks): for dst_block in range(dst_size_in_gpu_blocks):
src_block_candidate = dst_to_src.get(dst_block) src_block_candidate = dst_to_src.get(dst_block)
for src_cache, dst_cache, orig_dst_cache, kv_dim in zip( for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
src_kv_caches, handler.src_tensors,
dst_kv_caches, handler.dst_tensors,
orig_dst_caches, orig_dst_caches,
handler.kv_dim_before_num_blocks, handler.kv_dim_before_num_blocks,
): ):
......
...@@ -528,9 +528,11 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode): ...@@ -528,9 +528,11 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
), ),
], ],
) )
@pytest.mark.parametrize("top_logprobs", [0, 3])
def test_spec_decode_logprobs( def test_spec_decode_logprobs(
logprobs_mode: LogprobsMode, logprobs_mode: LogprobsMode,
model_setup: tuple[str, str, str], model_setup: tuple[str, str, str],
top_logprobs: int,
): ):
"""Spec decode logprobs should match those of the base model. """Spec decode logprobs should match those of the base model.
...@@ -543,7 +545,7 @@ def test_spec_decode_logprobs( ...@@ -543,7 +545,7 @@ def test_spec_decode_logprobs(
prompt = "Hello world " * 50 prompt = "Hello world " * 50
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
) )
method, model_name, spec_model_name = model_setup method, model_name, spec_model_name = model_setup
max_model_len = 256 max_model_len = 256
......
...@@ -111,7 +111,7 @@ def create_sampling_metadata( ...@@ -111,7 +111,7 @@ def create_sampling_metadata(
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
generators=generators, generators=generators,
max_num_logprobs=0, max_num_logprobs=None,
no_penalties=no_penalties, no_penalties=no_penalties,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
frequency_penalties=frequency_penalties, frequency_penalties=frequency_penalties,
......
...@@ -44,8 +44,6 @@ def unsupported_array_schemas(): ...@@ -44,8 +44,6 @@ def unsupported_array_schemas():
@pytest.fixture @pytest.fixture
def unsupported_object_schemas(): def unsupported_object_schemas():
return [ return [
{"type": "object", "minProperties": 1},
{"type": "object", "maxProperties": 5},
{"type": "object", "propertyNames": {"pattern": "^[a-z]+$"}}, {"type": "object", "propertyNames": {"pattern": "^[a-z]+$"}},
{"type": "object", "patternProperties": {"^S": {"type": "string"}}}, {"type": "object", "patternProperties": {"^S": {"type": "string"}}},
] ]
...@@ -79,6 +77,8 @@ def supported_schema(): ...@@ -79,6 +77,8 @@ def supported_schema():
}, },
}, },
}, },
"minProperties": 1,
"maxProperties": 100,
} }
......
...@@ -7,7 +7,7 @@ Here we break down the requirements in 2 steps: ...@@ -7,7 +7,7 @@ Here we break down the requirements in 2 steps:
1. Build and install the Python libraries (both [pplx-kernels](https://github.com/ppl-ai/pplx-kernels) and [DeepEP](https://github.com/deepseek-ai/DeepEP)), including necessary dependencies like NVSHMEM. This step does not require any privileged access. Any user can do this. 1. Build and install the Python libraries (both [pplx-kernels](https://github.com/ppl-ai/pplx-kernels) and [DeepEP](https://github.com/deepseek-ai/DeepEP)), including necessary dependencies like NVSHMEM. This step does not require any privileged access. Any user can do this.
2. Configure NVIDIA driver to enable IBGDA. This step requires root access, and must be done on the host machine. 2. Configure NVIDIA driver to enable IBGDA. This step requires root access, and must be done on the host machine.
2 is necessary for multi-node deployment. Step 2 is necessary for multi-node deployment.
All scripts accept a positional argument as workspace path for staging the build, defaulting to `$(pwd)/ep_kernels_workspace`. All scripts accept a positional argument as workspace path for staging the build, defaulting to `$(pwd)/ep_kernels_workspace`.
...@@ -23,6 +23,6 @@ TORCH_CUDA_ARCH_LIST="10.0" bash install_python_libraries.sh ...@@ -23,6 +23,6 @@ TORCH_CUDA_ARCH_LIST="10.0" bash install_python_libraries.sh
Additional step for multi-node deployment: Additional step for multi-node deployment:
```bash ```bash
sudo bash configure_system_drivers.sh sudo bash configure_system_drivers.sh # update-initramfs can take several minutes
sudo reboot # Reboot is required to load the new driver sudo reboot # Reboot is required to load the new driver
``` ```
...@@ -43,6 +43,7 @@ FILES = [ ...@@ -43,6 +43,7 @@ FILES = [
"vllm/worker", "vllm/worker",
"vllm/v1/core", "vllm/v1/core",
"vllm/v1/engine", "vllm/v1/engine",
"vllm/v1/executor",
"vllm/v1/metrics", "vllm/v1/metrics",
"vllm/v1/pool", "vllm/v1/pool",
"vllm/v1/sample", "vllm/v1/sample",
...@@ -60,7 +61,6 @@ SEPARATE_GROUPS = [ ...@@ -60,7 +61,6 @@ SEPARATE_GROUPS = [
"vllm/model_executor", "vllm/model_executor",
# v1 related # v1 related
"vllm/v1/attention", "vllm/v1/attention",
"vllm/v1/executor",
"vllm/v1/kv_offload", "vllm/v1/kv_offload",
"vllm/v1/spec_decode", "vllm/v1/spec_decode",
"vllm/v1/structured_output", "vllm/v1/structured_output",
......
...@@ -503,15 +503,15 @@ def awq_dequantize( ...@@ -503,15 +503,15 @@ def awq_dequantize(
def awq_gemm( def awq_gemm(
input: torch.Tensor, input: torch.Tensor,
qweight: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor,
scales: torch.Tensor, scales: torch.Tensor,
qzeros: torch.Tensor,
split_k_iters: int, split_k_iters: int,
) -> torch.Tensor: ) -> torch.Tensor:
if envs.VLLM_USE_TRITON_AWQ: if envs.VLLM_USE_TRITON_AWQ:
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
# gptq # gptq
...@@ -637,8 +637,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -637,8 +637,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def _awq_gemm_fake( def _awq_gemm_fake(
input: torch.Tensor, input: torch.Tensor,
qweight: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor,
scales: torch.Tensor, scales: torch.Tensor,
qzeros: torch.Tensor,
split_k_iters: torch.SymInt, split_k_iters: torch.SymInt,
) -> torch.Tensor: ) -> torch.Tensor:
num_in_feats = input.size(0) num_in_feats = input.size(0)
...@@ -2408,6 +2408,29 @@ def cp_gather_cache( ...@@ -2408,6 +2408,29 @@ def cp_gather_cache(
) )
def cp_gather_and_upconvert_fp8_kv_cache(
src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
seq_lens: torch.Tensor,
workspace_starts: torch.Tensor,
batch_size: int,
) -> None:
"""Gather and upconvert FP8 KV cache to BF16 workspace.
Args:
src_cache: FP8 KV cache [num_blocks, block_size, 656]
dst: BF16 output workspace [total_tokens, 576]
block_table: Block indices [num_reqs, max_blocks]
seq_lens: Sequence lengths [num_reqs]
workspace_starts: Workspace start offsets [num_reqs]
batch_size: Number of requests
"""
torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache(
src_cache, dst, block_table, seq_lens, workspace_starts, batch_size
)
def indexer_k_quant_and_cache( def indexer_k_quant_and_cache(
k: torch.Tensor, k: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
......
...@@ -294,6 +294,12 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -294,6 +294,12 @@ class AttentionImpl(ABC, Generic[T]):
# Some features like decode context parallelism require the softmax lse. # Some features like decode context parallelism require the softmax lse.
can_return_lse_for_decode: bool = False can_return_lse_for_decode: bool = False
# Whether the attention impl supports Prefill Context Parallelism.
supports_pcp: bool = False
# Whether the attention impl(or ops) supports MTP
# when cp_kv_cache_interleave_size > 1
supports_mtp_with_cp_non_trivial_interleave_size: bool = False
# some attention backends might not always want to return lse # some attention backends might not always want to return lse
# even if they can return lse (for efficiency reasons) # even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode: bool = False need_to_return_lse_for_decode: bool = False
......
...@@ -252,35 +252,3 @@ def register_backend( ...@@ -252,35 +252,3 @@ def register_backend(
return lambda x: x return lambda x: x
return decorator return decorator
# Backwards compatibility alias for plugins
class _BackendMeta(type):
"""Metaclass to provide deprecation warnings when accessing _Backend."""
def __getattribute__(cls, name: str):
if name not in ("__class__", "__mro__", "__name__"):
logger.warning(
"_Backend has been renamed to AttentionBackendEnum. "
"Please update your code to use AttentionBackendEnum instead. "
"_Backend will be removed in a future release."
)
return getattr(AttentionBackendEnum, name)
def __getitem__(cls, name: str):
logger.warning(
"_Backend has been renamed to AttentionBackendEnum. "
"Please update your code to use AttentionBackendEnum instead. "
"_Backend will be removed in a future release."
)
return AttentionBackendEnum[name]
class _Backend(metaclass=_BackendMeta):
"""Deprecated: Use AttentionBackendEnum instead.
This class is provided for backwards compatibility with plugins
and will be removed in a future release.
"""
pass
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer.""" """Attention layer."""
from collections.abc import Callable import functools
from typing import cast from typing import cast
import torch import torch
...@@ -16,7 +16,9 @@ from vllm.attention.backends.abstract import ( ...@@ -16,7 +16,9 @@ from vllm.attention.backends.abstract import (
MLAAttentionImpl, MLAAttentionImpl,
) )
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
...@@ -47,58 +49,9 @@ from vllm.v1.kv_cache_interface import ( ...@@ -47,58 +49,9 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec, SlidingWindowSpec,
) )
if current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx9
else:
on_gfx9 = lambda *args, **kwargs: False
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__) logger = init_logger(__name__)
def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum,
attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None]:
if current_platform.is_rocm():
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
attn_backend = AttentionBackendEnum.ROCM_AITER_FA
elif (
attn_backend_override is None
and on_gfx9()
and attn_backend == AttentionBackendEnum.FLASH_ATTN
):
pass
else:
return AttentionBackendEnum.TORCH_SDPA, None
elif current_platform.is_cuda():
pass
elif current_platform.is_xpu():
assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend."
)
pass
else:
return AttentionBackendEnum.TORCH_SDPA, None
if attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}:
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
try:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
else:
flash_attn_varlen_func = None
return attn_backend, flash_attn_varlen_func
def _init_kv_cache_quant( def _init_kv_cache_quant(
layer: nn.Module, layer: nn.Module,
quant_config: QuantizationConfig | None, quant_config: QuantizationConfig | None,
...@@ -494,29 +447,15 @@ class MultiHeadAttention(nn.Module): ...@@ -494,29 +447,15 @@ class MultiHeadAttention(nn.Module):
attn_backend_override = None attn_backend_override = None
if multimodal_config is not None: if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend attn_backend_override = multimodal_config.mm_encoder_attn_backend
backend = get_vit_attn_backend(
self.attn_backend = get_vit_attn_backend(
head_size=head_size, head_size=head_size,
dtype=dtype, dtype=dtype,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
self.attn_backend = ( self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
backend self.attn_backend,
if backend
in {
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.PALLAS,
AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.FLASH_ATTN,
}
else AttentionBackendEnum.TORCH_SDPA
)
self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
attn_backend_override=attn_backend_override,
)
) )
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
...@@ -524,6 +463,17 @@ class MultiHeadAttention(nn.Module): ...@@ -524,6 +463,17 @@ class MultiHeadAttention(nn.Module):
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
self.fa_version = None
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
and current_platform.is_cuda()
):
self.fa_version = get_flash_attn_version()
assert self._flash_attn_varlen_func is not None
self._flash_attn_varlen_func = functools.partial(
self._flash_attn_varlen_func, fa_version=self.fa_version
)
logger.info_once( logger.info_once(
f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder." f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
) )
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
from vllm.config import MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend
logger = init_logger(__name__)
def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum | None,
) -> Callable | None:
# At this point,
# we already have the attn_backend,
# overriding logic is done in the platform-specific implementation.
# so we don't need to override backend here.
# Just return the attn_backend and flash_attn_varlen_func.
if attn_backend == AttentionBackendEnum.FLASH_ATTN:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
# if attn_backend is TORCH_SDPA,
# it will reach here and the flash_attn_varlen_func will be None.
return flash_attn_varlen_func
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float | None = None,
num_kv_heads: int | None = None,
prefix: str = "",
multimodal_config: MultiModalConfig | None = None,
) -> None:
"""
Args:
num_heads: number of attention heads per partition.
head_size: hidden_size per attention head.
scale: scale factor.
num_kv_heads: number of kv heads.
prefix: This has no effect, it is only here to make it easier to
swap between Attention and MultiHeadAttention
multimodal_config: configs for multi-modal.
"""
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
# Try to get vision attention backend from multimodal_config.
attn_backend_override = None
if multimodal_config is not None:
attn_backend_override = multimodal_config.mm_encoder_attn_backend
# Get device-specific vision attention backend.
self.attn_backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
attn_backend_override=attn_backend_override,
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
self.attn_backend,
)
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
@classmethod
def enabled(cls) -> bool:
return True
def reshape_qkv_to_4d(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bsz: int,
q_len: int,
kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reshape query, key, value to 4D tensors:
(batch_size, seq_len, num_heads, head_size)
"""
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
return query, key, value
def reshape_qkv_to_3d(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bsz: int,
q_len: int,
kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reshape query, key, value to 3D tensors:
(batch_size * seq_len, num_heads, head_size)
"""
query = query.view(bsz * q_len, self.num_heads, self.head_size)
key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=1)
value = torch.repeat_interleave(value, num_repeat, dim=1)
return query, key, value
def _forward_sdpa(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
# TODO(Isotr0py): Migrate MultiHeadAttention
assert cu_seqlens is not None
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
query, key, value = self.reshape_qkv_to_4d(
query, key, value, bsz, q_len, kv_len
)
output = vit_torch_sdpa_wrapper(
q=query,
k=key,
v=value,
cu_seqlens=cu_seqlens,
)
return output
def _forward_fa(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.flash_attn_varlen_func is not None, (
"Flash attention function is not set."
)
# # TODO(Isotr0py): Migrate MultiHeadAttention
assert cu_seqlens is not None and max_seqlen is not None
bsz = query.shape[0]
output = vit_flash_attn_wrapper(
q=query,
k=key,
v=value,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
batch_size=bsz,
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
)
return output
def forward_native(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
def forward_cuda(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
if self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:
raise ValueError(
f"Unsupported multi-modal encoder attention backend for CUDA: "
f"{self.attn_backend}."
)
def forward_cpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
def forward_xpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.is_flash_attn_backend, (
"XPU only supports FLASH_ATTN for vision attention."
)
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
def forward_tpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.attn_backend == AttentionBackendEnum.PALLAS, (
f"MMEncoderAttention on TPU only supports PALLAS backend, "
f"but got {self.attn_backend}."
)
if cu_seqlens is None:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
return out
logger.warning_once(
"PALLAS backend with cu_seqlens is not supported for ViT yet. ",
"Falling back to SDPA implementation.",
)
return self._forward_sdpa(query, key, value, cu_seqlens)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment