Commit 7e63ef82 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0' into v0.14.0-dev

parents 8cbcac5d b17039bc
......@@ -9,7 +9,7 @@ from PIL import Image
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig
from vllm.multimodal.utils import encode_image_base64
from vllm.multimodal.utils import encode_image_url
from vllm.platforms import current_platform
MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"
......@@ -74,7 +74,7 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
placeholders = [
{
"type": "image_url",
"image_url": {"url": f"data:image;base64,{encode_image_base64(image_pil)}"},
"image_url": {"url": encode_image_url(image_pil)},
}
for image_pil in image_urls
]
......@@ -145,7 +145,7 @@ def test_shared_storage_connector_hashes(tmp_path):
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor # noqa: F401
from transformers import AutoProcessor
# Create processor to handle the chat prompt
processor = AutoProcessor.from_pretrained(MODEL_NAME)
......
......@@ -25,6 +25,7 @@ def mock_lmcache_engine_event():
lora_id,
block_size,
medium,
lora_name,
):
self.block_hashes = block_hashes
self.parent_block_hash = parent_block_hash
......@@ -32,6 +33,7 @@ def mock_lmcache_engine_event():
self.lora_id = lora_id
self.block_size = block_size
self.medium = medium
self.lora_name = lora_name
return MockEvent(
block_hashes=["hash1", "hash2"],
......@@ -40,6 +42,7 @@ def mock_lmcache_engine_event():
lora_id=None,
block_size=16,
medium="GPU",
lora_name=None,
)
......@@ -109,6 +112,7 @@ class TestGetKVConnectorKVCacheEvents:
assert events[0].lora_id is None
assert events[0].block_size == 16
assert events[0].medium == "GPU"
assert events[0].lora_name is None
def test_converts_multiple_events(self, mock_connector):
"""Test conversion of multiple events from lmcache engine format."""
......@@ -121,6 +125,7 @@ class TestGetKVConnectorKVCacheEvents:
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
self.lora_name = None
events = [MockEvent(i) for i in range(5)]
mock_connector._lmcache_engine.get_kv_events.return_value = events
......@@ -150,6 +155,7 @@ class TestGetKVConnectorKVCacheEvents:
self.lora_id = 42
self.block_size = 32
self.medium = "DISK"
self.lora_name = "lora_example"
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEventWithLora()
......@@ -166,6 +172,7 @@ class TestGetKVConnectorKVCacheEvents:
assert event.lora_id == 42
assert event.block_size == 32
assert event.medium == "DISK"
assert event.lora_name == "lora_example"
def test_handles_none_parent_block_hash(self, mock_connector):
"""Test handling of events with None parent_block_hash."""
......@@ -178,6 +185,7 @@ class TestGetKVConnectorKVCacheEvents:
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
self.lora_name = None
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEventNoParent()
......@@ -223,6 +231,7 @@ class TestUpdateConnectorOutput:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
kv_events.add_events([event])
......@@ -243,6 +252,7 @@ class TestUpdateConnectorOutput:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
existing_events.add_events([event1])
existing_events.add_events([event1]) # Simulate 2 workers reporting
......@@ -258,6 +268,7 @@ class TestUpdateConnectorOutput:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
new_events.add_events([event2])
......@@ -288,6 +299,7 @@ class TestUpdateConnectorOutput:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
new_events.add_events([event])
......@@ -309,6 +321,7 @@ class TestUpdateConnectorOutput:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
events1.add_events([event1])
output1 = KVConnectorOutput(kv_cache_events=events1)
......@@ -323,6 +336,7 @@ class TestUpdateConnectorOutput:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
events2.add_events([event2])
output2 = KVConnectorOutput(kv_cache_events=events2)
......@@ -337,6 +351,7 @@ class TestUpdateConnectorOutput:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
events3.add_events([event3])
output3 = KVConnectorOutput(kv_cache_events=events3)
......@@ -358,6 +373,7 @@ class TestUpdateConnectorOutput:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
events1.add_events([event1])
output1 = KVConnectorOutput(kv_cache_events=events1)
......@@ -397,6 +413,7 @@ class TestTakeEvents:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
event2 = BlockStored(
block_hashes=["hash2"],
......@@ -405,6 +422,7 @@ class TestTakeEvents:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
kv_events.add_events([event1, event2])
mock_connector._kv_cache_events = kv_events
......@@ -431,6 +449,7 @@ class TestTakeEvents:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
uncommon_event = BlockStored(
block_hashes=["hash_uncommon"],
......@@ -439,6 +458,7 @@ class TestTakeEvents:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
# All 3 workers report common_event
......@@ -469,6 +489,7 @@ class TestTakeEvents:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
kv_events1.add_events([event1])
mock_connector._kv_cache_events = kv_events1
......@@ -491,6 +512,7 @@ class TestTakeEvents:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
kv_events2.add_events([event2])
mock_connector._kv_cache_events = kv_events2
......@@ -510,6 +532,7 @@ class TestTakeEvents:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
event2 = BlockStored(
block_hashes=["hash2"],
......@@ -518,6 +541,7 @@ class TestTakeEvents:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
# Worker 1 reports event1
......@@ -572,6 +596,7 @@ class TestIntegrationScenarios:
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
self.lora_name = None
# Worker 1
mock_connector._lmcache_engine.get_kv_events.return_value = [
......@@ -628,6 +653,7 @@ class TestIntegrationScenarios:
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
self.lora_name = None
for cycle in range(3):
# Get events
......@@ -667,6 +693,7 @@ class TestIntegrationScenarios:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
worker1_unique_event = BlockStored(
......@@ -676,6 +703,7 @@ class TestIntegrationScenarios:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
worker2_unique_event = BlockStored(
......@@ -685,6 +713,7 @@ class TestIntegrationScenarios:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
worker3_unique_event = BlockStored(
......@@ -694,6 +723,7 @@ class TestIntegrationScenarios:
block_size=16,
lora_id=None,
medium="GPU",
lora_name=None,
)
# Create events for each worker
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
import os
from unittest.mock import MagicMock, patch
import msgspec
import pytest
import torch
import zmq
from tests.conftest import _find_free_port
from vllm.config import (
CacheConfig,
DeviceConfig,
KVTransferConfig,
ModelConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
MoRIIOAgentMetadata,
MoRIIOConnectorMetadata,
MoRIIOConstants,
zmq_ctx,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
KVConnectorRole,
MoRIIOConnector,
MoRIIOConnectorWorker,
)
from vllm.platforms import current_platform
from vllm.utils.network_utils import (
get_ip,
make_zmq_path,
)
from .utils import create_request, create_scheduler
# Probe for the optional ROCm-only packages without importing them.
aiter_available = importlib.util.find_spec("aiter") is not None
mori_available = importlib.util.find_spec("mori") is not None

# Skip every test in this module unless running on ROCm with `mori` present.
# The reason names `mori` because that is the package the condition checks;
# tests that additionally need `aiter` carry their own skipif marker.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_rocm() and mori_available),
    reason="MoRIIO connector tests are only available on ROCm with the mori "
    "package installed",
)
@pytest.fixture
def mock_parallel_groups():
    """Patch the parallel-group lookups of the MoRIIO modules for one rank.

    Yields a ``MagicMock`` group whose ``rank`` and ``local_rank`` are 0 and
    whose ``world_size`` is 1. The patches stay active while the fixture is
    in use.
    """
    fake_group = MagicMock()
    fake_group.rank = 0
    fake_group.local_rank = 0
    fake_group.world_size = 1

    # NOTE(review): the tensor-parallel world size is patched to 0 in both
    # modules while the mock group reports world_size == 1 -- confirm that 0
    # is intended for these single-rank tests.
    common_patches = patch.multiple(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common",
        get_tensor_model_parallel_rank=MagicMock(return_value=0),
        get_tensor_model_parallel_world_size=MagicMock(return_value=0),
    )
    connector_patches = patch.multiple(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector",
        get_tensor_model_parallel_world_size=MagicMock(return_value=0),
        get_world_group=MagicMock(return_value=fake_group),
        get_tp_group=MagicMock(return_value=fake_group),
    )
    with common_patches, connector_patches:
        yield fake_group
def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789):
"""Setup KV transfer parameters for a request."""
request.kv_transfer_params.update(
{
"remote_notify_port": fake_port,
"remote_block_ids": None,
"remote_host": remote_host,
"remote_port": fake_port,
"remote_handshake_port": fake_port,
"remote_engine_id": "test_engine",
}
)
return request
class FakeMorIIOWrapper:
    """Inert stand-in for MoRIIOWrapper used by the tests.

    Every method accepts its arguments, does nothing and returns ``None``.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments."""

    def set_moriio_engine(self, moriio_engine):
        """No-op."""

    def set_backend_type(self, backend_type):
        """No-op."""

    def get_agent_metadata(self):
        """No-op; returns None."""

    def register_remote_engine(self, remote_packed_engine_metadata):
        """No-op."""

    def register_local_tensor(self, tensor: torch.Tensor):
        """No-op."""

    def get_unpack_memory_metadata(self, packed_memory_metadata):
        """No-op; returns None."""

    def build_session(self, local_memory_metadata, remote_memory_metadata):
        """No-op; returns None."""

    def read_remote_data(
        self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
    ):
        """No-op; returns None."""

    def write_remote_data(
        self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
    ):
        """No-op; returns None."""

    def write_remote_data_single(
        self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
    ):
        """No-op; returns None."""

    def waiting_for_transfer_complete(self):
        """No-op."""

    def async_wait_reqid(self):
        """No-op."""

    def _handle_message(self, msg: bytes):
        """No-op."""

    def _handle_structured_message(self, data: dict):
        """No-op."""

    def _handle_completion_message(self, msg: str):
        """No-op."""

    def send_notify(self, req_ids, remote_ip, remote_port):
        """No-op."""

    def pop_finished_req_ids(self):
        """No-op; returns None."""

    def pop_finished_write_req_ids(self):
        """No-op; returns None."""

    def shutdown(self):
        """No-op."""
class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker):
    """MoRIIOConnectorWorker subclass used as the worker in these tests."""

    # Fake remote engine id for testing.
    REMOTE_ENGINE_ID = "remote_engine"

    def __init__(
        self,
        *args,
        hand_shake_latency: float = 1.8,
        kv_cache_layout="HND",
        **kwargs,
    ):
        # `hand_shake_latency` and `kv_cache_layout` are accepted so callers
        # can pass them, but they are neither stored nor forwarded; only the
        # remaining arguments reach the real worker constructor.
        super().__init__(*args, **kwargs)
def create_vllm_config(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 64,
    block_size: int = 16,
    max_model_len: int = 10000,
    enable_chunked_prefill: bool = True,
    enable_permute_local_kv: bool = False,
    role="kv_consumer",
) -> VllmConfig:
    """Build a CPU-device VllmConfig wired to the MoRIIOConnector.

    The scheduler limits, model name, block size and KV role come from the
    arguments. Prefix caching is always enabled, the model dtype is bfloat16
    with seed 42, and the KV connector is always ``"MoRIIOConnector"``.
    """
    # Sub-configs are built inline in the same order as the keyword list:
    # scheduler, model, cache, KV transfer, device.
    return VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            max_model_len=max_model_len,
            enable_chunked_prefill=enable_chunked_prefill,
            is_encoder_decoder=False,
        ),
        model_config=ModelConfig(
            model=model,
            trust_remote_code=True,
            dtype="bfloat16",
            seed=42,
        ),
        cache_config=CacheConfig(
            block_size=block_size,
            gpu_memory_utilization=0.9,
            swap_space=0,
            cache_dtype="auto",
            enable_prefix_caching=True,
        ),
        kv_transfer_config=KVTransferConfig(
            kv_connector="MoRIIOConnector",
            kv_role=role,
            enable_permute_local_kv=enable_permute_local_kv,
        ),
        device_config=DeviceConfig("cpu"),
    )
@pytest.fixture
def moriio_read_mode():
    """Force the connector into read mode via the environment for one test.

    Sets ``VLLM_MORIIO_CONNECTOR_READ_MODE`` to ``"True"`` for the duration of
    the test. On teardown the variable is restored to whatever it held before
    the fixture ran (or removed if it was unset), so a value exported by the
    surrounding environment is not lost.
    """
    env_key = "VLLM_MORIIO_CONNECTOR_READ_MODE"
    previous = os.environ.get(env_key)
    os.environ[env_key] = "True"
    try:
        yield
    finally:
        # Restore rather than blindly delete: the variable may have been set
        # before this fixture ran.
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
def test_write_mode_saves_local_block_ids():
    """Producer (write mode): the scheduled request lands in reqs_to_save.

    Checks that exactly one request is queued for saving, none for receiving
    or sending, and that its recorded local block ids match the blocks the KV
    cache manager holds for the request.
    """
    # Scheduler configured as KV producer.
    vllm_config = create_vllm_config(role="kv_producer")
    scheduler = create_scheduler(vllm_config)

    # Prompt length: two full blocks plus half a block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_decode=True,
        do_remote_prefill=False,
    )
    request_id = request.request_id
    scheduler.add_request(request)

    # Point the request at a fake remote endpoint.
    request = _setup_kv_transfer_request(request)

    # One scheduling step should produce MoRIIOConnectorMetadata.
    metadata = scheduler.schedule().kv_connector_metadata
    assert metadata is not None, "kv_connector_metadata is None"
    assert isinstance(metadata, MoRIIOConnectorMetadata)
    assert len(metadata.reqs_to_save) == 1, "Unexpected number of reqs_to_save"
    assert len(metadata.reqs_to_recv) == 0, "Unexpected number of reqs_to_recv"
    assert len(metadata.reqs_to_send) == 0, "Unexpected number of reqs_to_send"
    assert request_id in metadata.reqs_to_save, "Request ID not in reqs_to_save"

    # Recorded block ids must line up with the blocks held by the manager.
    saved_meta = metadata.reqs_to_save[request_id]
    block_manager = scheduler.kv_cache_manager.coordinator.single_type_managers[0]
    for block_id, block in zip(
        saved_meta.local_block_ids, block_manager.req_to_blocks[request_id]
    ):
        assert block_id == block.block_id, f"{block_id} != {block.block_id}"
def test_write_mode_with_chunked_prefill_saves_local_block_ids():
    """Write mode with chunked prefill still records correct local block ids.

    The prompt is 2.5x the per-step token budget, so the scheduler is run for
    three steps. Only the last step is expected to place the request in
    ``reqs_to_save``; its local block ids must match the blocks held by the
    KV cache manager.
    """
    # Setup Scheduler and Request
    MAX_NUM_BATCHED_TOKENS = 64
    NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2
    vllm_config = create_vllm_config(
        max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_producer"
    )
    BLOCK_SIZE = vllm_config.cache_config.block_size
    scheduler = create_scheduler(vllm_config)

    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_decode=True,
        do_remote_prefill=False,
    )
    request_id = request.request_id
    scheduler.add_request(request)

    # Fake Config
    request = _setup_kv_transfer_request(request)

    # Expected (save, recv, send) counts for each successive schedule() call.
    expected_counts = [(0, 0, 0), (0, 0, 0), (1, 0, 0)]
    kv_connector_metadata = None
    for expected_save, expected_recv, expected_send in expected_counts:
        scheduler_output = scheduler.schedule()
        kv_connector_metadata = scheduler_output.kv_connector_metadata
        # Check for None before touching any attribute so a missing metadata
        # object fails with this message instead of an AttributeError.
        assert kv_connector_metadata is not None, "kv_connector_metadata is None"
        assert len(kv_connector_metadata.reqs_to_save) == expected_save
        assert len(kv_connector_metadata.reqs_to_recv) == expected_recv
        assert len(kv_connector_metadata.reqs_to_send) == expected_send

    assert kv_connector_metadata is not None, "kv_connector_metadata is None"
    assert request_id in kv_connector_metadata.reqs_to_save, (
        "Request ID not in reqs_to_save"
    )
    req_meta = kv_connector_metadata.reqs_to_save[request_id]

    for block_id, block in zip(
        req_meta.local_block_ids,
        scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
            request_id
        ],
    ):
        assert block_id == block.block_id, f"{block_id} != {block.block_id}"
def test_read_mode_loads_remote_block_ids(moriio_read_mode):
    """Consumer (read mode): the scheduled request lands in reqs_to_recv.

    Checks that exactly one request is queued for receiving, none for saving
    or sending, and that its recorded local block ids match the blocks the KV
    cache manager holds for the request.
    """
    # Scheduler configured as KV consumer.
    vllm_config = create_vllm_config(role="kv_consumer")
    scheduler = create_scheduler(vllm_config)

    # Prompt length: two full blocks plus half a block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_decode=False,
        do_remote_prefill=True,
    )
    request_id = request.request_id
    scheduler.add_request(request)

    # Whatever the manager currently maps to this request is used as the
    # "remote" block list to fetch.
    block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_id]
    request = _setup_kv_transfer_request(request)
    request.kv_transfer_params["remote_block_ids"] = block_list

    # One scheduling step should produce MoRIIOConnectorMetadata.
    metadata = scheduler.schedule().kv_connector_metadata
    assert metadata is not None, "kv_connector_metadata is None"
    assert isinstance(metadata, MoRIIOConnectorMetadata), (
        "kv_connector_metadata is not MoRIIOConnectorMetadata"
    )
    assert len(metadata.reqs_to_save) == 0, "Unexpected number of reqs_to_save"
    assert len(metadata.reqs_to_recv) == 1, "Unexpected number of reqs_to_recv"
    assert len(metadata.reqs_to_send) == 0, "Unexpected number of reqs_to_send"
    assert request_id in metadata.reqs_to_recv, "Request ID not in reqs_to_recv"

    # Recorded block ids must line up with the blocks held by the manager.
    recv_meta = metadata.reqs_to_recv[request_id]
    block_manager = scheduler.kv_cache_manager.coordinator.single_type_managers[0]
    for block_id, block in zip(
        recv_meta.local_block_ids, block_manager.req_to_blocks[request_id]
    ):
        assert block_id == block.block_id, f"{block_id} != {block.block_id}"
@pytest.mark.skipif(
    not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend"
)
def test_register_kv_caches(mock_parallel_groups):
    """MoRIIOConnector.register_kv_caches records a MemoryDesc per layer.

    Three layers are registered, two of which share one tensor. For each
    layer the first stored descriptor must point at that layer's tensor, and
    the descriptor's engine key must encode role, IP, port and TP/DP rank.
    """
    ROLE = "kv_consumer"
    IP = get_ip()
    vllm_config = create_vllm_config(role=ROLE)
    # Values expected inside the engine key.
    DEFAULT_PORT = 6301
    TP_RANK = 0
    DP_RANK = 0

    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend

    # Allocate test KV cache tensors with the backend's own cache shape.
    kv_cache_shape = AiterFlashAttentionBackend.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }

    event_patch = patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event"
    )
    thread_patch = patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread"
    )
    with event_patch, thread_patch:
        # Build the connector and swap in the fake worker.
        vllm_config.kv_transfer_config.kv_connector_extra_config.update(
            {
                "proxy_ip": "127.0.0.1",
                "proxy_ping_port": 12345,
                "http_port": 12346,
            }
        )
        connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeMorIIOConnectorWorker(
            vllm_config, connector.engine_id, hand_shake_latency=0
        )

        from mori.io import (
            MemoryDesc,
        )

        connector.register_kv_caches(kv_caches)

        # Each layer's first descriptor must reference that layer's tensor.
        for layer_name, tensor in kv_caches.items():
            descriptor = MemoryDesc.unpack(
                connector.connector_worker.layer_name_to_local_kv_cache_metadata[
                    layer_name
                ][0]
            )
            assert tensor.data_ptr() == descriptor.data

        # The engine key drops the "kv_" prefix from the role.
        expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
        layer0_descriptor = MemoryDesc.unpack(
            connector.connector_worker.layer_name_to_local_kv_cache_metadata[
                "layer0"
            ][0]
        )
        assert layer0_descriptor.engine_key == expected_engine_key
@pytest.mark.skipif(
    not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend"
)
def test_moriio_handshake_returns_metadata(mock_parallel_groups):
    """The MoRIIO handshake socket answers GET_META with agent metadata.

    A connector is created with the MoRIIO wrapper replaced by the fake, KV
    caches are registered, and a ZMQ DEALER socket then requests metadata
    from the handshake port. The reply must be a two-frame message whose
    second frame decodes to ``MoRIIOAgentMetadata``.
    """
    ROLE = "kv_consumer"
    vllm_config = create_vllm_config(role=ROLE)

    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend

    # Allocate test KV cache tensors with the backend's own cache shape.
    kv_cache_shape = AiterFlashAttentionBackend.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }

    wrapper_patch = patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
        FakeMorIIOWrapper,
    )
    with wrapper_patch:
        handshake_port = _find_free_port()

        # Build the connector listening on the free handshake port.
        extra_config = {
            "proxy_ip": "127.0.0.1",
            "proxy_ping_port": 12345,
            "http_port": 12346,
            "handshake_port": handshake_port,
        }
        vllm_config.kv_transfer_config.kv_connector_extra_config.update(extra_config)
        connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
        connector.register_kv_caches(kv_caches)

        # Ask the handshake socket for the agent metadata.
        path = make_zmq_path("tcp", "127.0.0.1", handshake_port)
        with zmq_ctx(zmq.DEALER, path) as sock:
            sock.send(MoRIIOConstants.GET_META_MSG)
            received_frame = sock.recv_multipart()

            # Expect exactly [empty delimiter, payload].
            frame_is_valid = len(received_frame) == 2 and received_frame[0] == b""
            if not frame_is_valid:
                raise ValueError(f"Unexpected frame! {received_frame = }")

            metadata_bytes = received_frame[1]
            metadata = msgspec.msgpack.Decoder(MoRIIOAgentMetadata).decode(
                metadata_bytes
            )
            assert isinstance(metadata, MoRIIOAgentMetadata), (
                "Decoded metadata is not MoRIIOAgentMetadata"
            )
......@@ -51,6 +51,33 @@ class MockConnector(KVConnectorBase_V1):
) -> KVConnectorStats | None:
return MockConnectorStats(data=data) if data is not None else None
def start_load_kv(self, forward_context, **kwargs):
    """No-op stub; ignores its arguments and returns None."""
    return None
def wait_for_layer_load(self, layer_name):
    """No-op stub; ignores ``layer_name`` and returns None."""
    return None
def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
    """No-op stub; ignores its arguments and returns None."""
    return None
def wait_for_save(self):
    """No-op stub; returns None."""
    return None
def build_connector_meta(self, scheduler_output):
    """Stub that ignores ``scheduler_output`` and always returns None."""
    connector_meta = None
    return connector_meta
def get_num_new_matched_tokens(self, request, num_computed_tokens):
    """Stub that ignores its arguments and always returns ``(0, False)``."""
    matched_tokens = 0
    second_field = False
    return matched_tokens, second_field
def update_state_after_alloc(self, request, blocks, num_tokens) -> None:
    """No-op stub; ignores its arguments and returns None."""
    return None
class MockCrossLayerConnector(MockConnector):
    """MockConnector variant whose ``prefer_cross_layer_blocks`` is always True."""

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        """Always report a preference for cross-layer blocks."""
        prefers_cross_layer = True
        return prefers_cross_layer
# Register the mock connector
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
......@@ -603,3 +630,21 @@ class TestMultiConnectorStats:
# One non-empty
stats.data["NixlConnector"].data["transfer_duration"].append(1.0)
assert not stats.is_empty()
class TestMultiConnectorPreferCrossLayerBlocks:
    """Checks MultiConnector.prefer_cross_layer_blocks over its sub-connectors."""

    @staticmethod
    def _multi_connector_of(*connector_types):
        """Build a MultiConnector around bare instances, bypassing every __init__."""
        multi = MultiConnector.__new__(MultiConnector)
        multi._connectors = [
            connector_type.__new__(connector_type)
            for connector_type in connector_types
        ]
        return multi

    def test_all_connectors_prefer_cross_layer_blocks(self):
        """True when every sub-connector prefers cross-layer blocks."""
        mc = self._multi_connector_of(MockCrossLayerConnector, MockCrossLayerConnector)
        assert mc.prefer_cross_layer_blocks is True

    def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
        """False when one sub-connector is a plain MockConnector."""
        mc = self._multi_connector_of(MockCrossLayerConnector, MockConnector)
        assert mc.prefer_cross_layer_blocks is False
......@@ -41,10 +41,13 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
has_kv_transfer_group,
)
from vllm.forward_context import ForwardContext
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
......@@ -182,18 +185,21 @@ class FakeNixlWrapper:
def _make_fake_nixl_pkg():
"""Context manager that creates a temporary package making
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.
Also creates rixl package for ROCm compatibility.
Automatically cleans up the temporary directory when done.
"""
with tempfile.TemporaryDirectory() as td:
pkg_root = os.path.join(td, "nixl", "_api")
os.makedirs(pkg_root, exist_ok=True)
# Create both nixl and rixl packages for cross-platform compatibility
for pkg_name in ["nixl", "rixl"]:
pkg_root = os.path.join(td, pkg_name, "_api")
os.makedirs(pkg_root, exist_ok=True)
# Get the source code of FakeNixlWrapper class and dedent it
fake_nixl_source = inspect.getsource(FakeNixlWrapper)
fake_nixl_source = textwrap.dedent(fake_nixl_source)
# Get the source code of FakeNixlWrapper class and dedent it
fake_nixl_source = inspect.getsource(FakeNixlWrapper)
fake_nixl_source = textwrap.dedent(fake_nixl_source)
stub = f"""\
stub = f"""\
# Copy of FakeNixlWrapper implementation for Ray workers
import uuid
from collections import defaultdict
......@@ -203,16 +209,17 @@ from collections import defaultdict
# Export as nixl_agent
nixl_agent = FakeNixlWrapper
"""
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
f.write(stub)
# Mock nixlXferTelemetry class
pkg_root2 = os.path.join(td, "nixl", "_bindings")
os.makedirs(pkg_root2, exist_ok=True)
with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
f.write("class nixlXferTelemetry: pass")
# touch parent package
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
f.write(stub)
# Mock nixlXferTelemetry class
pkg_root2 = os.path.join(td, pkg_name, "_bindings")
os.makedirs(pkg_root2, exist_ok=True)
with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
f.write("class nixlXferTelemetry: pass")
# touch parent package
open(os.path.join(td, pkg_name, "__init__.py"), "w").close()
yield td
......@@ -296,6 +303,7 @@ def test_prompt_less_than_block_size():
)
def test_kv_transfer_handshake(dist_init):
"""Unit test for basic NixlConnector interface functionality."""
from vllm.config import set_current_vllm_config
# Test setup, we create a scheduler that contains a NixlConnector
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
......@@ -305,81 +313,82 @@ def test_kv_transfer_handshake(dist_init):
vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
scheduler = create_scheduler(vllm_config)
# Create two NixlConnector of role WORKER, one is the worker of
# the scheduler (prefill), the other is a worker of decode instance.
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
prefill_connector.register_kv_caches(kv_caches)
# Simulate EngineCore initialization that would gather connector
# metadata from all workers
metadata = prefill_connector.get_handshake_metadata()
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes)
# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata({0: metadata})
with set_current_vllm_config(vllm_config):
# Create two NixlConnector of role WORKER, one is the worker of
# the scheduler (prefill), the other is a worker of decode instance.
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
request, [0, 1, 2]
)
assert delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
# to validate the metadata received is the same as the one in prefill_connector.
with patch.object(
decode_connector.connector_worker, "add_remote_agent"
) as mock_add_remote_agent:
mock_add_remote_agent.return_type = "remote_agent"
decode_connector.connector_worker._nixl_handshake(
kv_connector_metadata["remote_host"],
kv_connector_metadata["remote_port"],
kv_connector_metadata["tp_size"],
kv_connector_metadata["remote_engine_id"],
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
prefill_connector.register_kv_caches(kv_caches)
# Simulate EngineCore initialization that would gather connector
# metadata from all workers
metadata = prefill_connector.get_handshake_metadata()
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes)
# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata({0: metadata})
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
request, [0, 1, 2]
)
assert delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
# to validate the metadata received is the same as the one in prefill_connector.
with patch.object(
decode_connector.connector_worker, "add_remote_agent"
) as mock_add_remote_agent:
mock_add_remote_agent.return_type = "remote_agent"
decode_connector.connector_worker._nixl_handshake(
kv_connector_metadata["remote_host"],
kv_connector_metadata["remote_port"],
kv_connector_metadata["tp_size"],
kv_connector_metadata["remote_engine_id"],
)
received_metadata = mock_add_remote_agent.call_args.args
assert received_metadata[0] == expected_agent_metadata
assert received_metadata[1] == 0 # remote_tp_rank
assert received_metadata[2] == 1 # remote_tp_size
received_metadata = mock_add_remote_agent.call_args.args
assert received_metadata[0] == expected_agent_metadata
assert received_metadata[1] == 0 # remote_tp_rank
assert received_metadata[2] == 1 # remote_tp_size
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector.shutdown()
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector.shutdown()
class FakeNixlConnectorWorker(NixlConnectorWorker):
......@@ -391,6 +400,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
......@@ -407,22 +418,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
assert expected_engine_id == self.REMOTE_ENGINE_ID
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=self.block_len_per_layer,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_size=remote_tp_size,
)
return {0: remote_agent_name}
# Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation.
remote_block_lens = list(self.block_len_per_layer)
tp_ratio = self.kv_topo.tp_ratio(remote_tp_size)
if remote_tp_size > self.world_size:
# P TP > D TP case, block_len of remote is smaller
remote_block_lens = [
block_len // (-tp_ratio) for block_len in remote_block_lens
]
elif remote_tp_size < self.world_size:
remote_block_lens = [
block_len * tp_ratio for block_len in remote_block_lens
]
# When remote tp_size > local tp_size, handshake with multiple
# remote ranks.
num_hanshakes = 1 if tp_ratio > 0 else -tp_ratio
remote_agents: dict[int, str] = {}
for remote_tp_rank in range(num_hanshakes):
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=remote_tp_rank,
num_blocks=1,
block_lens=remote_block_lens,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
)
remote_agents[remote_tp_rank] = remote_agent_name
return remote_agents
class TestNixlHandshake:
......@@ -432,6 +464,7 @@ class TestNixlHandshake:
)
def test_multi_xfer_one_engine(
self,
default_vllm_config,
# dist_init is a fixture that initializes the distributed environment.
dist_init,
):
......@@ -453,7 +486,13 @@ class TestNixlHandshake:
vllm_config, connector.engine_id, hand_shake_latency=0
)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
worker = connector.connector_worker
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
# simulate handshake
worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
}
worker.kv_cache_layout = "HND"
num_xfers = 4
while True:
# For the same request_id, initiate multiple xfers across different
......@@ -515,6 +554,7 @@ class TestNixlHandshake:
)
def test_async_load_kv(
self,
default_vllm_config,
# Fixture that initializes the distributed environment.
dist_init,
# Simulate consumer-producer TP sizes.
......@@ -567,12 +607,178 @@ class TestNixlHandshake:
return
raise TimeoutError("Took too long to complete async handshake.")
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
    self, local_tp_size: int, default_vllm_config, dist_init
):
    """
    Check that the handshake works when the remote (prefill) TP size is
    larger than the local (decode) TP size.

    Two remote engines are handshaked one after the other (remote TP 2, then
    remote TP 6). After each handshake the returned per-rank agents and the
    transfer handles recorded on the worker are verified.
    """
    vllm_config = create_vllm_config()
    # NOTE(review): this rebinding discards the parametrized value, so every
    # parametrization runs with a local TP size of 1 -- confirm this is
    # intended.
    local_tp_size = 1
    vllm_config.parallel_config.tensor_parallel_size = local_tp_size

    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        vllm_config, connector.engine_id, hand_shake_latency=0
    )
    worker = connector.connector_worker

    # Minimal local registration state used by add_remote_agent.
    worker.slot_size_per_layer = [4096]
    worker.block_len_per_layer = [4096 * worker.block_size]
    worker.num_blocks = 1
    worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
    worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]

    def handshake_and_verify(remote_tp_size: int):
        """Handshake with the current remote engine, then check the result."""
        remote_agents = worker._nixl_handshake(
            host="localhost",
            port=1234,
            remote_tp_size=remote_tp_size,
            expected_engine_id=worker.REMOTE_ENGINE_ID,
        )

        tp_ratio = remote_tp_size // local_tp_size
        expected_remote_ranks = set(range(tp_ratio))
        # One agent is expected per remote rank.
        assert set(remote_agents.keys()) == expected_remote_ranks

        remote_engine_id = worker.REMOTE_ENGINE_ID
        assert worker._tp_size[remote_engine_id] == remote_tp_size
        # The worker reports a negated ratio when the remote TP is larger.
        assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id)

        # src_xfer_handles_by_tp_ratio must hold `tp_ratio` chunk handles.
        assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
        assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio

        assert remote_engine_id in worker.dst_xfer_side_handles
        dst_ranks = set(worker.dst_xfer_side_handles[remote_engine_id].keys())
        assert dst_ranks == expected_remote_ranks

    handshake_and_verify(2)

    # NOTE flexibility: a second remote with a higher number of ranks is
    # discovered. This is not a scenario we actively support right now, but
    # the connector allows it.
    worker.REMOTE_ENGINE_ID = "remote_engine_2"
    handshake_and_verify(6)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
    self, local_tp_size: int, default_vllm_config, dist_init
):
    """
    Exercise the remote TP > local TP case for an MLA model from the
    producer (P) side.

    Two P workers (TP=2, MLA enabled) track the same request. Each receives a
    read notification from a decode instance with TP=1, and the test checks
    that both report the request as done sending, that the aggregated output
    marks it finished, and that the per-worker tracking state is cleared.
    """
    # NOTE(review): `local_tp_size` is parametrized but never read in this
    # test; the decode TP size is fixed by `d_tp_size` -- confirm intent.
    vllm_config = create_vllm_config()
    d_tp_size = 1
    p_tp_size = 2

    # Build two separate connectors/workers to emulate P TP=2 ranks.
    conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    producers = (conn_p0, conn_p1)
    for conn in producers:
        conn.connector_worker = FakeNixlConnectorWorker(
            vllm_config, conn.engine_id, hand_shake_latency=0
        )

    # Force P world size to 2 for both workers and emulate distinct tp_ranks.
    # Also enable MLA path so that expected_finished_count is updated.
    for rank, conn in enumerate(producers):
        p_worker = conn.connector_worker
        p_worker.world_size = p_tp_size
        p_worker.kv_topo.remote_tp_size = {p_worker.engine_id: p_tp_size}
        p_worker.tp_rank = rank
        p_worker.use_mla = True

    req_id = "req-ep-dp2-p0"
    now = time.perf_counter()
    # Register a request on P that is waiting for consumers to read
    # (both workers track it).
    for conn in producers:
        conn.connector_worker._reqs_to_send[req_id] = now + 10.0
        conn.connector_worker._reqs_to_process.add(req_id)

    # Simulate a read notification coming from D with (tp=1, dp=2); every P
    # rank is handed the same notification payload.
    notif = f"{req_id}:{d_tp_size}".encode()
    for conn in producers:
        conn.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
            "agent": [notif]
        }  # type: ignore[method-assign]

    # Trigger notification processing via get_finished().
    done_sending0, _ = conn_p0.get_finished(finished_req_ids=set())
    done_sending1, _ = conn_p1.get_finished(finished_req_ids=set())
    assert req_id in done_sending0
    assert req_id in done_sending1

    # E2E aggregation: ensure the aggregated output marks the request
    # as finished using the connector's expected_finished_count.
    from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

    aggregator = KVOutputAggregator.from_connector(conn_p0, world_size=2)

    def rank_output(finished_sending):
        """Build one rank's runner output carrying its finished-sending set."""
        return ModelRunnerOutput(
            req_ids=[req_id],
            req_id_to_index={req_id: 0},
            sampled_token_ids=[[0]],
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[None],
            kv_connector_output=KVConnectorOutput(
                finished_sending=finished_sending,
                finished_recving=None,
            ),
        )

    out0 = rank_output(done_sending0)
    out1 = rank_output(done_sending1)
    aggregated = aggregator.aggregate([out0, out1], output_rank=0)
    assert aggregated.kv_connector_output is not None
    assert aggregated.kv_connector_output.finished_sending == {req_id}

    # Producers cleaned up state for the finished request.
    for conn in producers:
        assert req_id not in conn.connector_worker._reqs_to_send
        assert req_id not in conn.connector_worker._reqs_to_process
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_concurrent_load_kv(
self,
default_vllm_config,
# dist_init is a fixture that initializes the distributed environment.
dist_init,
):
......@@ -585,6 +791,9 @@ class TestNixlHandshake:
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id
)
# Register (mocked) local xfer handler
# worker = connector.connector_worker
# worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
metadata = NixlConnectorMetadata()
total_reqs = 5
for i in range(total_reqs):
......@@ -630,7 +839,9 @@ class TestNixlHandshake:
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
def test_handshake_fails_on_kv_cache_layout_mismatch(
self, default_vllm_config, dist_init
):
"""
Verify that adding a remote agent fails if kv_cache_layout differs.
This test is only relevant for heterogeneous TP.
......@@ -672,7 +883,6 @@ class TestNixlHandshake:
with pytest.raises(RuntimeError):
# mismatched layout is expected to fail
worker.add_remote_agent(meta, remote_tp_size=2)
with pytest.raises(AssertionError):
worker.add_remote_agent(meta, remote_tp_size=1)
@patch(
......@@ -680,7 +890,7 @@ class TestNixlHandshake:
FakeNixlWrapper,
)
def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
self, dist_init
self, default_vllm_config, dist_init
):
"""
Verify that adding a remote agent fails if kv_cache_layout differs.
......@@ -735,7 +945,7 @@ class TestNixlHandshake:
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_kv_connector_stats(dist_init):
def test_kv_connector_stats(default_vllm_config, dist_init):
"""Test that KV transfer stats are properly recorded and retrieved."""
vllm_config = create_vllm_config()
......@@ -1069,6 +1279,22 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
run_test_and_cleanup()
class RequestIdMapper:
    """Translate external request IDs into internal request IDs.

    On construction the given output processor's ``add_request`` is replaced
    by a wrapper that records ``external_req_id -> request_id`` for every
    request before forwarding the call to the original method. Calling the
    mapper with an external ID returns the recorded internal ID.
    """

    def __init__(self, output_processor: OutputProcessor):
        # Keep the original bound method so intercepted calls can be forwarded.
        self.original_add_request = output_processor.add_request
        self.req_id_mapping: dict[str, str] = {}
        output_processor.add_request = self._add_request

    def _add_request(self, request: EngineCoreRequest, *args, **kwargs):
        """Record the request's ID pair, then delegate to the original."""
        internal_id = request.request_id
        self.req_id_mapping[request.external_req_id] = internal_id
        return self.original_add_request(request, *args, **kwargs)

    def __call__(self, external_req_id: str) -> str:
        """Return the internal ID recorded for *external_req_id*.

        Raises KeyError if no request with that external ID was recorded.
        """
        internal_id = self.req_id_mapping[external_req_id]
        return internal_id
def _run_abort_timeout_test(llm: LLM, timeout: int):
"""Helper function to run the abort timeout test logic."""
remote_prefill_opts = {
......@@ -1090,24 +1316,34 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
0
].req_to_blocks
id_mapper = RequestIdMapper(llm.llm_engine.output_processor)
def req_id(outputs: list[RequestOutput]) -> str:
assert len(outputs) == 1
return id_mapper(outputs[0].request_id)
padding = "Just making this request a little longer so that we're sure "
"we're not hitting the small-request lower bound beneath which we don't "
"actually trigger the whole kv transfer, but rather just recompute the "
"blocks on D."
_ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
req0_id = req_id(
llm.generate([f"What is the capital of Japan? {padding}"], sampling_params)
)
# Request finished but not freed
assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks
assert req0_id in scheduler.finished_req_ids and req0_id in req_to_blocks
# Some other request, 0 still not freed
_ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
assert "0" in req_to_blocks
assert "1" in scheduler.finished_req_ids and "1" in req_to_blocks
req1_id = req_id(
llm.generate([f"What is the capital of Italy? {padding}"], sampling_params)
)
assert req0_id in req_to_blocks
assert req1_id in scheduler.finished_req_ids and req1_id in req_to_blocks
# Wait for timeout and trigger another scheduler loop
time.sleep(timeout)
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
# Request-0 times out and is cleared!
assert "0" not in req_to_blocks
assert req0_id not in req_to_blocks
# Need to shutdown the background thread to release NIXL side channel port
llm.llm_engine.engine_core.shutdown()
......@@ -1132,7 +1368,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
"TRITON_ATTN",
],
)
def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
......@@ -1144,9 +1380,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
block layout info
"""
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
vllm_config = create_vllm_config()
vllm_config = create_vllm_config(attention_backend=attn_backend)
# Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN":
......@@ -1205,7 +1439,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
patch(f"{nixl_module}.threading.Event"),
patch(f"{nixl_module}.threading.Thread") as mock_thread,
patch(f"{nixl_module}.get_attn_backend") as mock_get_attn_backend,
patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend,
):
# Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous
......@@ -1295,7 +1529,9 @@ class FakePlatform(Platform):
("oot", "VRAM"),
],
)
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory_type):
def test_kv_buffer_to_nixl_memory_types(
default_vllm_config, dist_init, kv_buffer_device, nixl_memory_type
):
"""
Test that register_kv_caches() passes the correct memory types from the
config to the nixl_wrapper.
......@@ -1340,7 +1576,7 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_shutdown_cleans_up_resources(dist_init):
def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()
......@@ -1359,8 +1595,11 @@ def test_shutdown_cleans_up_resources(dist_init):
patch.object(nixl_wrapper, "deregister_memory") as mock_dereg,
):
worker._recving_transfers = {"req1": [123]}
worker.src_xfer_side_handle = 456
worker.dst_xfer_side_handles = {"engine1": 789}
# Mock register_kv_cache which registers local handle
worker.src_xfer_handles_by_block_size = {worker.block_size: 455}
# P TP = 2 * D TP case, we should register 2 local handles
worker.src_xfer_handles_by_tp_ratio = {-2: [456, 457]}
worker.dst_xfer_side_handles = {"engine1": {0: 789}}
worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"]
......@@ -1381,8 +1620,10 @@ def test_shutdown_cleans_up_resources(dist_init):
mock_listener.join.assert_called_once()
mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
mock_rel_dlist.assert_any_call(456) # src handle
assert mock_rel_dlist.call_count == 4
mock_rel_dlist.assert_any_call(455) # src handle (whole region)
mock_rel_dlist.assert_any_call(456) # src handle (1st chunk)
mock_rel_dlist.assert_any_call(457) # src handle (2nd chunk)
mock_rel_dlist.assert_any_call(789) # dst handle
mock_rem_agent.assert_called_once_with("agent1")
assert mock_dereg.call_count == 2
......@@ -1394,7 +1635,7 @@ def test_shutdown_cleans_up_resources(dist_init):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_aborted_request_removed_from_worker_in_batch(dist_init):
def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_init):
"""
Create and schedule a request so that P adds it to in-batch tracking via
the real scheduler, then simulate an abort (request not in next scheduler
......@@ -1464,6 +1705,8 @@ class FailingNixlWrapper(FakeNixlWrapper):
self.fail_handshake = False
self.fail_transfer_setup = False
self.fail_send_notif = False
self.fail_transfer_state = False # Returns "ERR" state
self.fail_transfer_exception = False # Raises exception in check_xfer_state
def add_remote_agent(self, agent_metadata: bytes) -> str:
if self.fail_handshake:
......@@ -1498,12 +1741,156 @@ class FailingNixlWrapper(FakeNixlWrapper):
raise RuntimeError("Simulated send_notif failure")
return super().send_notif(agent_name, notif_msg)
def check_xfer_state(self, handle: int) -> str:
    """Report the transfer state, optionally simulating a failure.

    Raises RuntimeError when ``fail_transfer_exception`` is set, returns the
    bad state ``"ERR"`` when ``fail_transfer_state`` is set, and otherwise
    defers to the parent implementation.
    """
    if self.fail_transfer_exception:
        raise RuntimeError("Simulated check_xfer_state exception")
    # "ERR" stands in for a bad transfer state.
    return "ERR" if self.fail_transfer_state else super().check_xfer_state(handle)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FailingNixlWrapper,
)
@pytest.mark.parametrize(
    "failure_type,wrapper_config,needs_get_finished",
    [
        ("transfer_setup_failed", {"fail_transfer_setup": True}, False),
        ("handshake_failed", {"fail_handshake": True}, False),
        ("notification_failed", {"fail_send_notif": True}, False),
        ("transfer_failed", {"fail_transfer_state": True}, True),
        ("transfer_exception", {"fail_transfer_exception": True}, True),
    ],
)
def test_transfer_failure_logging(
    default_vllm_config,
    dist_init,
    failure_type,
    wrapper_config,
    needs_get_finished,
):
    """Test that transfer failures are logged with structured context.

    Run with `pytest -sv` to see the log output.

    Covers failure types:
    - transfer_setup_failed: make_prepped_xfer fails
    - handshake_failed: add_remote_agent fails during request handshake
    - notification_failed: send_notif fails
    - transfer_failed: check_xfer_state returns bad state (e.g., "ERR")
    - transfer_exception: check_xfer_state raises exception

    Args:
        failure_type: label expected to appear in the logged error.
        wrapper_config: attributes to set on the FailingNixlWrapper to
            trigger the failure.
        needs_get_finished: whether get_finished() must be called for the
            failure to surface.
    """
    import logging

    vllm_config = create_vllm_config()

    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        vllm_config, connector.engine_id, hand_shake_latency=0.0
    )

    # Configure FailingNixlWrapper to fail in the specified way
    for key, value in wrapper_config.items():
        setattr(connector.connector_worker.nixl_wrapper, key, value)

    request_id = f"test_{failure_type}_req"

    # For notification_failed, we need empty local blocks
    # (full cache hit path to trigger send_notif)
    local_blocks = [] if failure_type == "notification_failed" else [10, 11, 12]
    remote_blocks = [20, 21, 22]

    metadata = NixlConnectorMetadata()
    metadata.add_new_req_to_recv(
        request_id=request_id,
        local_block_ids=local_blocks,
        kv_transfer_params={
            "remote_block_ids": remote_blocks,
            "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            "remote_request_id": f"prefill-{request_id}",
            "remote_host": "localhost",
            "remote_port": 1234,
            "remote_tp_size": 1,
        },
    )
    connector.bind_connector_metadata(metadata)

    dummy_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )

    # Capture logs from the nixl_connector logger specifically
    # vLLM loggers have propagate=False, so we need to capture directly
    nixl_logger = logging.getLogger(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
    )
    captured_logs: list[logging.LogRecord] = []

    class LogCapture(logging.Handler):
        def emit(self, record):
            captured_logs.append(record)

    handler = LogCapture()
    handler.setLevel(logging.ERROR)
    nixl_logger.addHandler(handler)

    try:
        connector.start_load_kv(dummy_ctx)
        # Process the ready_requests queue (for async handshake)
        connector.bind_connector_metadata(NixlConnectorMetadata())
        # Wait for async handshake to complete
        time.sleep(0.2)
        connector.start_load_kv(dummy_ctx)

        # For transfer_failed/transfer_exception, the error happens in
        # get_finished() when checking transfer state
        if needs_get_finished:
            connector.get_finished(finished_req_ids=set())
    finally:
        # Always detach the handler so it does not leak into other tests.
        nixl_logger.removeHandler(handler)

    error_logs = [r for r in captured_logs if r.levelno >= logging.ERROR]
    # Use getMessage() rather than the `message` attribute: `message` only
    # exists once some Formatter has formatted the record, and LogCapture
    # never formats, so relying on it would depend on other handlers.
    all_messages = [record.getMessage() for record in error_logs]

    # Print logs for manual comparison between commits
    print("\n" + "=" * 60)
    print(f"CAPTURED ERROR LOGS for {failure_type}:")
    print("=" * 60)
    for i, message in enumerate(all_messages):
        print(f"\n--- Log {i + 1} ---")
        print(f"Message: {message}")
    print("=" * 60 + "\n")

    assert len(error_logs) >= 1, f"Expected at least one error log for {failure_type}"

    # Verify structured logging output (new format)
    # Check that at least one log matches the expected format
    combined_logs = "\n".join(all_messages)

    assert any("NIXL transfer failure" in msg for msg in all_messages), (
        f"Expected structured log format with 'NIXL transfer failure' prefix "
        f"for {failure_type}. Got: {all_messages}"
    )
    assert any("failure_type" in msg for msg in all_messages), (
        f"Expected 'failure_type' in logs. Got: {all_messages}"
    )
    assert any("Context:" in msg for msg in all_messages), (
        f"Expected 'Context:' in logs. Got: {all_messages}"
    )

    # Check that the expected failure_type appears in at least one log
    # Note: handshake_failed also triggers handshake_setup_failed
    assert failure_type in combined_logs or (
        failure_type == "handshake_failed" and "handshake_setup_failed" in combined_logs
    ), f"Expected '{failure_type}' in logs. Got: {all_messages}"
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FailingNixlWrapper,
)
def test_handshake_failure_returns_finished(dist_init):
def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
"""Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config = create_vllm_config()
......@@ -1552,7 +1939,7 @@ def test_handshake_failure_returns_finished(dist_init):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FailingNixlWrapper,
)
def test_transfer_setup_failure_returns_finished(dist_init):
def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init):
"""Test that transfer setup failures mark blocks invalid
and return via get_finished."""
vllm_config = create_vllm_config()
......@@ -1627,6 +2014,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
FakeNixlWrapper,
)
def test_compatibility_hash_validation(
default_vllm_config,
dist_init,
mismatch_type,
config_overrides,
......@@ -1739,7 +2127,7 @@ def test_compatibility_hash_validation(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_handshake_decode_errors(dist_init, error_scenario):
def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario):
"""
Test that msgspec decode errors are properly handled during handshake.
......
......@@ -26,6 +26,7 @@ from vllm.v1.core.kv_cache_utils import (
init_none_hash,
)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_offload.abstract import (
LoadStoreSpec,
OffloadingEvent,
......@@ -64,8 +65,11 @@ class MockLoadStoreSpec(LoadStoreSpec):
class MockOffloadingHandler(OffloadingHandler):
def __init__(self):
self.transfer_specs: dict[int, TransferSpec] = {}
self.completed_transfers: list[TransferResult] = []
self.completed_specs: list[TransferSpec] = []
self.waiting_jobs: set[int] = set()
self.completed_jobs: list[int] = []
self.flushed_jobs: set[int] = set()
def get_finished(self) -> list[TransferResult]:
finished = self.completed_transfers
......@@ -73,14 +77,25 @@ class MockOffloadingHandler(OffloadingHandler):
return finished
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
self.completed_specs.append(spec)
self.completed_transfers.append((job_id, True))
self.transfer_specs[job_id] = spec
self.waiting_jobs.add(job_id)
return True
def complete_jobs(self, job_ids: set[int]) -> None:
    """Mark each still-waiting job in *job_ids* as completed.

    IDs that are not currently in ``waiting_jobs`` are ignored. A completed
    job is removed from ``waiting_jobs``, appended to ``completed_jobs``, and
    recorded in ``completed_transfers`` as ``(job_id, True)``.
    """
    for job_id in job_ids:
        if job_id not in self.waiting_jobs:
            continue
        self.waiting_jobs.remove(job_id)
        self.completed_jobs.append(job_id)
        self.completed_transfers.append((job_id, True))
def wait(self, job_ids: set[int]) -> None:
    """Record *job_ids* as flushed, then complete them right away."""
    # In-place update keeps the same `flushed_jobs` set object.
    self.flushed_jobs.update(job_ids)
    self.complete_jobs(job_ids)
class MockOffloadingSpec(OffloadingSpec):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig):
super().__init__(vllm_config, kv_cache_config)
self.manager = MagicMock(spec=OffloadingManager)
self.manager.lookup.return_value = 0
......@@ -98,9 +113,22 @@ class MockOffloadingSpec(OffloadingSpec):
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
def complete_transfers(self):
    """Complete every job the handler currently has waiting."""
    # Pass a snapshot: complete_jobs removes entries from waiting_jobs while
    # iterating over the IDs it was given.
    pending_job_ids = self.handler.waiting_jobs.copy()
    self.handler.complete_jobs(pending_job_ids)
def get_completed_transfers(self) -> list[TransferSpec]:
specs = self.handler.completed_specs
self.handler.completed_specs = []
specs = [
self.handler.transfer_specs[job_id]
for job_id in self.handler.completed_jobs
]
self.handler.completed_jobs.clear()
return specs
def get_flushed_transfers(self):
    """Return the transfer specs of all flushed jobs and reset the flushed set.

    The specs are looked up in the handler's ``transfer_specs`` by job ID;
    afterwards ``flushed_jobs`` is emptied so each flush is reported once.
    """
    handler = self.handler
    flushed_specs = []
    for job_id in handler.flushed_jobs:
        flushed_specs.append(handler.transfer_specs[job_id])
    handler.flushed_jobs.clear()
    return flushed_specs
......@@ -170,11 +198,9 @@ class RequestRunner:
# mapping (offloading address) -> gpu_block_index
self.offloaded: dict[Any, int] = {}
self.pending_loads_count: int = 0
self.pending_stores_count: int = 0
self.completed_loads: list[TransferSummary] = []
self.completed_stores: list[TransferSummary] = []
self.flushed_gpu_block_indexes: set[int] = set()
# maps {block_id: block_offset}
self.gpu_block_index: dict[int, int] = {}
......@@ -201,54 +227,60 @@ class RequestRunner:
self.scheduler.add_request(req)
def _wait_for_transfers(self):
def _parse_transfers(self):
for transfer_spec in self.offloading_spec.get_flushed_transfers():
src_spec, dst_spec = transfer_spec
assert isinstance(src_spec, GPULoadStoreSpec)
for block_id in src_spec.block_ids:
self.flushed_gpu_block_indexes.add(
self.gpu_block_index[block_id.item()]
)
block_size_factor = self.offloaded_block_size // self.gpu_block_size
while self.pending_loads_count or self.pending_stores_count:
for transfer_spec in self.offloading_spec.get_completed_transfers():
src_spec, dst_spec = transfer_spec
if isinstance(src_spec, GPULoadStoreSpec):
store = True
gpu_spec = src_spec
offload_spec = dst_spec
else:
store = False
gpu_spec = dst_spec
offload_spec = src_spec
assert isinstance(offload_spec, MockLoadStoreSpec)
assert isinstance(gpu_spec, GPULoadStoreSpec)
gpu_block_indices: list[int] = []
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(self.gpu_block_index[block_id.item()])
# list of (block_hash, sub_block_offset)
offload_addresses: list[Any] = []
for block_hash in offload_spec.block_hashes:
for sub_block_idx in range(block_size_factor):
offload_addresses.append((block_hash, sub_block_idx))
if store:
assert len(gpu_block_indices) == len(offload_addresses)
self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_stores_count -= 1
else:
remainder_sub_block_count = len(offload_addresses) - len(
gpu_block_indices
)
assert remainder_sub_block_count >= 0
assert remainder_sub_block_count < block_size_factor
offload_addresses = offload_addresses[remainder_sub_block_count:]
self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_loads_count -= 1
for transfer_spec in self.offloading_spec.get_completed_transfers():
src_spec, dst_spec = transfer_spec
if isinstance(src_spec, GPULoadStoreSpec):
store = True
gpu_spec = src_spec
offload_spec = dst_spec
else:
store = False
gpu_spec = dst_spec
offload_spec = src_spec
assert isinstance(offload_spec, MockLoadStoreSpec)
assert isinstance(gpu_spec, GPULoadStoreSpec)
gpu_block_indices: list[int] = []
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(self.gpu_block_index[block_id.item()])
# list of (block_hash, sub_block_offset)
offload_addresses: list[Any] = []
for block_hash in offload_spec.block_hashes:
for sub_block_idx in range(block_size_factor):
offload_addresses.append((block_hash, sub_block_idx))
if store:
assert len(gpu_block_indices) == len(offload_addresses)
self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
else:
remainder_sub_block_count = len(offload_addresses) - len(
gpu_block_indices
)
assert remainder_sub_block_count >= 0
assert remainder_sub_block_count < block_size_factor
offload_addresses = offload_addresses[remainder_sub_block_count:]
self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
def _update_gpu_block_idx(self):
for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[
......@@ -257,18 +289,19 @@ class RequestRunner:
for block_idx, block in enumerate(blocks):
self.gpu_block_index[block.block_id] = block_idx
def _run(self, decoded_tokens: list[int]):
def _run(self, decoded_tokens: list[int], complete_transfers: bool):
"""
Runs multiple engine (scheduler + worker) steps.
Assumes a single request is running.
Args:
decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
"""
tokens_iter = iter(decoded_tokens)
token_id = next(tokens_iter, None)
while token_id is not None:
while True:
assert self.scheduler.requests
scheduler_output = self.scheduler.schedule()
......@@ -278,8 +311,10 @@ class RequestRunner:
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
self.pending_stores_count += len(kv_connector_metadata.reqs_to_store)
if scheduler_output.preempted_req_ids:
self.worker_connector.handle_preemptions(
scheduler_output.preempted_req_ids
)
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx)
......@@ -287,6 +322,9 @@ class RequestRunner:
if scheduler_output.total_num_scheduled_tokens > 0:
self.worker_connector.wait_for_save()
if complete_transfers:
self.offloading_spec.complete_transfers()
finished_sending, finished_recving = self.worker_connector.get_finished(
scheduler_output.finished_req_ids
)
......@@ -297,7 +335,7 @@ class RequestRunner:
reqs=self.scheduler.running,
finished_sending=finished_sending,
finished_recving=finished_recving,
token_id=token_id,
token_id=token_id or 0,
)
if self.scheduler.running:
......@@ -305,7 +343,10 @@ class RequestRunner:
self.scheduler.update_from_output(scheduler_output, model_runner_output)
self._wait_for_transfers()
if token_id is None:
break
self._parse_transfers()
# run one more step to update finished stored
if EOS_TOKEN_ID in decoded_tokens:
......@@ -330,8 +371,10 @@ class RequestRunner:
def run(
self,
decoded_tokens: list[int],
complete_transfers: bool = True,
expected_stored_gpu_block_indexes: tuple[int, ...] = (),
expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
expected_flushed_gpu_block_indexes: tuple[int, ...] = (),
):
"""
Runs multiple engine (scheduler + worker) steps.
......@@ -339,14 +382,17 @@ class RequestRunner:
Args:
decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
expected_stored_gpu_block_indexes: GPU block indexes
that are expected to be written during the run.
expected_loaded_gpu_block_indexes: GPU block indexes
that are expected to be loaded during the run.
expected_flushed_gpu_block_indexes: GPU block indexes
that are expected to be flushed during the run.
"""
self.manager.reset_mock()
self._run(decoded_tokens)
self._run(decoded_tokens, complete_transfers)
loaded_gpu_block_indexes: set[int] = set()
for transfer in self.completed_loads:
......@@ -370,6 +416,9 @@ class RequestRunner:
assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes
self.completed_stores.clear()
assert set(expected_flushed_gpu_block_indexes) == self.flushed_gpu_block_indexes
self.flushed_gpu_block_indexes.clear()
@pytest.fixture
def request_runner():
......@@ -414,10 +463,13 @@ def test_offloading_connector(request_runner):
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
)
runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))
runner.run(decoded_tokens=[0])
# add block missing 1 token -> no offload
runner.run(decoded_tokens=[0] * (offloaded_block_size - 1))
runner.run(
decoded_tokens=[0] * (offloaded_block_size - 1),
expected_stored_gpu_block_indexes=(3, 4, 5),
)
runner.manager.prepare_store.assert_not_called()
# +1 token -> single block, fail prepare_store
......@@ -435,23 +487,20 @@ def test_offloading_connector(request_runner):
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(block_hashes)
)
runner.run(
decoded_tokens=[0] * offloaded_block_size,
expected_stored_gpu_block_indexes=(15, 16, 17),
)
runner.run(decoded_tokens=[0] * offloaded_block_size)
runner.manager.touch.assert_called()
block_hashes1 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes1) == 6
# terminate request
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=(15, 16, 17),
)
# create a new request differing only on the last token
runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
runner.run(
decoded_tokens=[0],
expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
)
runner.run(decoded_tokens=[0])
runner.manager.touch.assert_called()
block_hashes2 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes2) == 6
......@@ -461,7 +510,10 @@ def test_offloading_connector(request_runner):
assert block_hashes1[5] != block_hashes2[5]
# terminate request
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
)
# full_block_tokens - num_computed_tokens < offloaded_block_size
runner.new_request(
......@@ -528,7 +580,74 @@ def test_offloading_connector(request_runner):
assert event.token_ids == []
assert event.parent_block_hash is None
assert event.lora_id is None
assert event.lora_name is None
event = events[1]
assert isinstance(event, BlockRemoved)
assert event.block_hashes == to_hashes([4, 5, 6])
assert event.medium == "B"
def test_request_preemption(request_runner):
    """Exercise a request that is preempted mid-run and later resumed.

    The run is driven with ``complete_transfers=False`` until the request is
    preempted, at which point the stored blocks are expected to be flushed.
    After free space is restored the request is expected to reload its blocks
    and store the final one when it terminates.
    """
    gpu_block_size = 4
    offloaded_block_size = 12
    num_gpu_blocks = 100

    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
    )

    def store_every_block(block_hashes):
        # Accept every block offered to prepare_store.
        return generate_store_output(block_hashes)

    block_pool = runner.scheduler.kv_cache_manager.block_pool
    free_block_queue = block_pool.free_block_queue
    # Remember the initial free-block count so it can be restored later.
    num_free_blocks_empty = free_block_queue.num_free_blocks

    # GPU block indexes covered by the first three offloaded blocks.
    first_nine_gpu_blocks = tuple(range(9))

    # Two offloaded blocks ([0, 1, 2] and [3, 4, 5]): store all, no flushing.
    runner.new_request(token_ids=[0] * (2 * offloaded_block_size))
    runner.manager.prepare_store.side_effect = store_every_block
    runner.run(decoded_tokens=[0], complete_transfers=False)

    # Decode two more offloaded blocks minus one GPU block,
    # storing [6, 7, 8] (still no flush).
    runner.manager.prepare_store.side_effect = store_every_block
    num_decode_tokens = 2 * offloaded_block_size - gpu_block_size
    runner.run(
        decoded_tokens=[0] * num_decode_tokens,
        complete_transfers=False,
    )

    # Pretend the KV cache has no free blocks left.
    free_block_queue.num_free_blocks = 0

    # With no space available the request should get preempted.
    runner.run(
        decoded_tokens=[],
        complete_transfers=False,
        expected_flushed_gpu_block_indexes=first_nine_gpu_blocks,
        expected_stored_gpu_block_indexes=first_nine_gpu_blocks,
    )

    # Give the space back and drop the GPU prefix cache.
    free_block_queue.num_free_blocks = num_free_blocks_empty
    runner.scheduler.reset_prefix_cache()

    # The request should resume: blocks [0, ..., 8] are re-loaded from the
    # CPU and [9, 10, 11] get stored.
    runner.manager.lookup.return_value = 3
    runner.manager.prepare_store.side_effect = store_every_block
    runner.run(
        decoded_tokens=[0] * gpu_block_size,
        expected_loaded_gpu_block_indexes=first_nine_gpu_blocks,
    )
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(9, 10, 11),
    )
......@@ -11,6 +11,7 @@ import torch
from vllm import SamplingParams
from vllm.config import (
AttentionConfig,
CacheConfig,
DeviceConfig,
KVTransferConfig,
......@@ -94,6 +95,7 @@ def create_vllm_config(
dtype: str = "float16",
cache_dtype: str = "auto",
hf_overrides: dict[str, Any] | None = None,
attention_backend: str | None = None,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
model_config = ModelConfig(
......@@ -131,12 +133,14 @@ def create_vllm_config(
enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config=kv_connector_extra_config or {},
)
attention_config = AttentionConfig(backend=attention_backend)
return VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"),
attention_config=attention_config,
)
......@@ -151,7 +155,13 @@ def create_scheduler(
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
["layer"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
)
],
)
......
......@@ -7,6 +7,7 @@ import pytest
import torch
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers
......@@ -49,6 +50,7 @@ NUM_MAPPINGS = [3]
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_transfer(
default_vllm_config,
gpu_to_cpu: bool,
num_mappings: int,
head_size: int,
......@@ -62,7 +64,7 @@ def test_transfer(
seed: int,
device: str,
) -> None:
current_platform.seed_everything(seed)
set_random_seed(seed)
# create per-layer GPU KV caches based on available attn_backends
attn_backends_list = BACKENDS_TO_TEST
......
......@@ -13,13 +13,12 @@ from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.platforms import current_platform
from vllm.utils.system_utils import set_env_var
CPU_BLOCK_SIZES = [48]
ATTN_BACKENDS = ["FLASH_ATTN"]
ATTN_BACKENDS = []
if current_platform.is_cuda():
ATTN_BACKENDS.append("FLASHINFER")
ATTN_BACKENDS = ["FLASH_ATTN", "FLASHINFER", "TRITON_ATTN"]
elif current_platform.is_rocm():
ATTN_BACKENDS = ["TRITON_ATTN"]
......@@ -162,7 +161,7 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
kv_connector="OffloadingConnector",
kv_role="kv_both",
kv_connector_extra_config={
"num_cpu_blocks": 1000,
"cpu_bytes_to_use": 500 << 20,
"block_size": cpu_block_size,
},
)
......@@ -180,13 +179,13 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
topic="test",
)
with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend):
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_events_config=kv_events_config,
kv_transfer_config=kv_transfer_config,
)
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_events_config=kv_events_config,
kv_transfer_config=kv_transfer_config,
attention_config={"backend": attn_backend},
)
events_endpoint = events_endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)
......
......@@ -63,6 +63,12 @@ class OffloadingHandler1To2(OffloadingHandler):
del self.transfers[job_id]
return finished
def wait(self, job_ids: set[int]) -> None:
    """Assert that every requested job still tracked here has finished.

    Job ids with no entry in ``self.transfers`` are ignored.
    """
    tracked_specs = (self.transfers.get(job_id) for job_id in job_ids)
    for tracked in tracked_specs:
        if not tracked:
            continue
        assert tracked.finished
class OffloadingHandler2To1(OffloadingHandler):
def __init__(self):
......@@ -84,6 +90,12 @@ class OffloadingHandler2To1(OffloadingHandler):
del self.transfers[job_id]
return finished
def wait(self, job_ids: set[int]) -> None:
    """Check that each given job known to this handler is already finished.

    Ids that are absent from ``self.transfers`` are skipped silently.
    """
    for pending_id in job_ids:
        entry = self.transfers.get(pending_id)
        if not entry:
            continue
        assert entry.finished
def test_offloading_worker():
"""
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the analytic estimators in vllm.v1.metrics.perf.
"""
import types
from types import SimpleNamespace
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
from transformers.models.llama4.configuration_llama4 import (
Llama4Config,
Llama4TextConfig,
)
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
from vllm.config.model import ModelConfig, get_hf_text_config
from vllm.transformers_utils.model_arch_config_convertor import (
MODEL_ARCH_CONFIG_CONVERTORS,
ModelArchConfigConvertorBase,
)
from vllm.v1.metrics.perf import (
AttentionMetrics,
BaseConfigParser,
ExecutionContext,
FfnMetrics,
ModelMetrics,
ParsedArgs,
UnembedMetrics,
)
class MockModelConfig:
    """Lightweight stand-in for ``ModelConfig`` used by the parser tests.

    Only the attributes assigned in ``__init__`` live on the instance. Any
    other attribute is looked up on the real ``ModelConfig`` class and
    evaluated against this mock (see ``__getattr__``).
    """

    def __init__(self, hf_config, dtype):
        self.hf_config = hf_config
        self.hf_text_config = get_hf_text_config(hf_config)

        # Model types without a registered convertor use the base convertor.
        arch_convertor = MODEL_ARCH_CONFIG_CONVERTORS.get(
            self.hf_config.model_type, ModelArchConfigConvertorBase
        )(self.hf_config, self.hf_text_config)
        self.model_arch_config = arch_convertor.convert()

        self.dtype = dtype
        self.is_attention_free = False

    def __getattr__(self, name):
        """Resolve attributes missing on the mock through ``ModelConfig``.

        Properties are evaluated with this instance as ``self``, plain
        functions are bound to this instance, and anything else is returned
        as found on the class.

        Raises:
            AttributeError: if ``ModelConfig`` has no attribute ``name``.
        """
        if not hasattr(ModelConfig, name):
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}' "
                f"and neither does 'ModelConfig'."
            )

        class_attr = getattr(ModelConfig, name)

        if isinstance(class_attr, property):
            # Run the property's getter against the mock instance.
            return class_attr.__get__(self, type(self))

        if isinstance(class_attr, types.FunctionType):
            # Bind the function so the mock is passed as its first argument.
            return types.MethodType(class_attr, self)

        # Class-level value: hand it back unchanged.
        return class_attr
def create_mock_vllm_config(
    hf_config,
    model_dtype="bfloat16",
    cache_dtype="auto",
    quant_config=None,
    data_parallel_size=1,
    tensor_parallel_size=1,
    pipeline_parallel_size=1,
    enable_expert_parallel=False,
) -> SimpleNamespace:
    """Assemble a namespace carrying the config fields set by this helper.

    The result exposes ``model_config`` (a ``MockModelConfig``),
    ``cache_config.cache_dtype``, ``quant_config`` and a ``parallel_config``
    holding the data/tensor/pipeline parallel sizes and the expert-parallel
    flag.
    """
    return SimpleNamespace(
        model_config=MockModelConfig(hf_config, model_dtype),
        cache_config=SimpleNamespace(cache_dtype=cache_dtype),
        quant_config=quant_config,
        parallel_config=SimpleNamespace(
            data_parallel_size=data_parallel_size,
            tensor_parallel_size=tensor_parallel_size,
            pipeline_parallel_size=pipeline_parallel_size,
            enable_expert_parallel=enable_expert_parallel,
        ),
    )
#### Parser Tests ####
def test_base_config_parser():
    """BaseConfigParser should pick up the basic model dimensions and byte sizes."""
    qwen_config = Qwen3Config(
        vocab_size=50000,
        hidden_size=2048,
        num_attention_heads=16,
        num_hidden_layers=24,
    )
    mock_config = create_mock_vllm_config(qwen_config, model_dtype="float16")

    parsed = BaseConfigParser().parse(ParsedArgs(), mock_config)

    # Dimensions are taken over from the HF config.
    assert parsed.vocab_size == 50000
    assert parsed.hidden_size == 2048
    assert parsed.num_attention_heads == 16
    assert parsed.num_hidden_layers == 24
    # float16 weights occupy 2 bytes each.
    assert parsed.weight_byte_size == 2
    # Default activation size.
    assert parsed.activation_byte_size == 2
def test_base_attention_config_parser_with_gqa():
    """The attention parser chain should report explicit GQA settings."""
    gqa_config = Qwen3Config(
        hidden_size=4096,
        num_attention_heads=32,
        num_key_value_heads=8,  # 4 query heads per KV head
        head_dim=128,
    )
    mock_config = create_mock_vllm_config(gqa_config)

    parsed = AttentionMetrics.get_parser().parse(mock_config)

    assert parsed.num_key_value_heads == 8
    assert parsed.head_dim == 128
def test_base_attention_config_parser_without_gqa():
    """
    When num_key_value_heads is not passed, the attention parser chain should
    fall back to MHA.
    """
    # num_key_value_heads is deliberately left out here.
    mha_config = Qwen3Config(hidden_size=4096, num_attention_heads=32)
    mock_config = create_mock_vllm_config(mha_config)

    parsed = AttentionMetrics.get_parser().parse(mock_config)

    # MHA: as many KV heads as attention heads.
    assert parsed.num_key_value_heads == 32
def test_base_ffn_config_parser_dense():
    """The FFN parser chain should report no experts for a dense model."""
    dense_config = Qwen3Config(
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
    )
    mock_config = create_mock_vllm_config(dense_config)

    parsed = FfnMetrics.get_parser().parse(mock_config)

    assert parsed.intermediate_size == 11008
    # Dense model: every MoE-related count is zero.
    assert parsed.num_experts == 0
    assert parsed.num_experts_per_tok == 0
    assert parsed.num_moe_layers == 0
def test_base_ffn_config_parser_moe():
    """The FFN parser chain should report the MoE settings of a MoE model."""
    moe_config = Qwen3MoeConfig(
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_experts=64,
        num_experts_per_tok=8,
        moe_intermediate_size=14336,
        n_shared_experts=2,
    )
    mock_config = create_mock_vllm_config(moe_config)

    parsed = FfnMetrics.get_parser().parse(mock_config)

    assert parsed.num_experts == 64
    assert parsed.num_experts_per_tok == 8
    assert parsed.moe_intermediate_size == 14336
    assert parsed.num_shared_experts == 2
    # By default every layer counts as a MoE layer.
    assert parsed.num_moe_layers == 32
def test_interleave_moe_layer_step_parser():
    """The MoE layer count should follow interleave_moe_layer_step."""
    text_config = Llama4TextConfig(
        num_hidden_layers=32,
        num_local_experts=64,
        interleave_moe_layer_step=4,  # one MoE layer out of every 4
    )
    llama4_config = Llama4Config(text_config=text_config)
    mock_config = create_mock_vllm_config(llama4_config)

    parsed = FfnMetrics.get_parser().parse(mock_config)

    # 32 layers with a step of 4 are expected to yield 8 MoE layers.
    assert parsed.num_moe_layers == 8
def test_moe_layer_freq_parser():
    """The MoE layer count should follow moe_layer_freq and first_k_dense_replace."""
    deepseek_config = DeepseekV3Config(
        num_hidden_layers=30,
        n_routed_experts=64,
        moe_layer_freq=3,  # every 3rd layer past the dense prefix
        first_k_dense_replace=6,  # layers 0..5 stay dense
    )
    mock_config = create_mock_vllm_config(deepseek_config)

    parsed = FfnMetrics.get_parser().parse(mock_config)

    # Expected MoE layers: index >= 6 and divisible by 3,
    # i.e. 6, 9, 12, 15, 18, 21, 24, 27.
    expected_moe_layers = sum(
        1 for layer in range(30) if layer >= 6 and layer % 3 == 0
    )
    assert expected_moe_layers == 8
    assert parsed.num_moe_layers == expected_moe_layers
#### ComponentMetrics Tests ####
def test_attention_metrics_scaling():
    """Doubling the layer count should double attention FLOPs and read/write bytes."""
    shared_dims = dict(
        hidden_size=2048,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=128,
    )

    base_metrics = AttentionMetrics.from_vllm_config(
        create_mock_vllm_config(Qwen3Config(num_hidden_layers=12, **shared_dims))
    )
    # Same model with twice as many layers.
    double_layers_metrics = AttentionMetrics.from_vllm_config(
        create_mock_vllm_config(Qwen3Config(num_hidden_layers=24, **shared_dims))
    )

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    # FLOPs, read bytes and write bytes must each double with the layers.
    for getter in ("get_num_flops", "get_read_bytes", "get_write_bytes"):
        base_value = getattr(base_metrics, getter)(ctx)
        doubled_value = getattr(double_layers_metrics, getter)(ctx)
        assert doubled_value == 2 * base_value
def test_attention_metrics_grouped_query():
    """GQA should read fewer bytes than MHA for the same decode step."""
    shared_dims = dict(
        hidden_size=4096,
        num_attention_heads=32,
        num_hidden_layers=1,
    )

    # MHA: one KV head per attention head.
    mha_config = create_mock_vllm_config(
        Qwen3Config(num_key_value_heads=32, **shared_dims)
    )
    # GQA: 4 attention heads share each KV head.
    gqa_config = create_mock_vllm_config(
        Qwen3Config(num_key_value_heads=8, **shared_dims)
    )

    mha_metrics = AttentionMetrics.from_vllm_config(mha_config)
    gqa_metrics = AttentionMetrics.from_vllm_config(gqa_config)

    decode_ctx = ExecutionContext.from_single_request(
        num_tokens=1, context_len=1024, is_prefill=False
    )

    mha_read = mha_metrics.get_read_bytes(decode_ctx)
    gqa_read = gqa_metrics.get_read_bytes(decode_ctx)
    # Fewer KV heads are expected to mean fewer KV cache bytes read.
    assert gqa_read < mha_read
def test_ffn_metrics_scaling():
    """Doubling intermediate_size should double the FFN FLOPs."""
    shared_dims = dict(hidden_size=2048, num_hidden_layers=12)

    base_metrics = FfnMetrics.from_vllm_config(
        create_mock_vllm_config(Qwen3Config(intermediate_size=8192, **shared_dims))
    )
    # Same model with a twice-as-wide intermediate layer.
    larger_ffn_metrics = FfnMetrics.from_vllm_config(
        create_mock_vllm_config(Qwen3Config(intermediate_size=16384, **shared_dims))
    )

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    base_flops = base_metrics.get_num_flops(ctx)
    larger_flops = larger_ffn_metrics.get_num_flops(ctx)
    assert larger_flops == base_flops * 2
def test_moe_metrics_vs_dense():
    """A MoE FFN with 2 routed experts per token should cost twice the dense FLOPs."""
    shared_dims = dict(
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=12,
    )

    dense_config = create_mock_vllm_config(Qwen3Config(**shared_dims))
    moe_config = create_mock_vllm_config(
        Qwen3MoeConfig(
            num_experts=64,
            num_experts_per_tok=2,  # 2 routed experts per token
            moe_intermediate_size=8192,
            n_shared_experts=0,
            **shared_dims,
        )
    )

    dense_metrics = FfnMetrics.from_vllm_config(dense_config)
    moe_metrics = FfnMetrics.from_vllm_config(moe_config)

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    dense_flops = dense_metrics.get_num_flops(ctx)
    moe_flops = moe_metrics.get_num_flops(ctx)

    # Two routed experts versus a single dense FFN of the same width.
    assert moe_flops == dense_flops * 2
def test_unembed_metrics_scaling():
    """Doubling the vocabulary should double the unembedding FLOPs."""
    small_vocab_config = create_mock_vllm_config(
        Qwen3Config(hidden_size=2048, vocab_size=32000)
    )
    # Same hidden size, twice the vocabulary.
    large_vocab_config = create_mock_vllm_config(
        Qwen3Config(hidden_size=2048, vocab_size=64000)
    )

    small_vocab_metrics = UnembedMetrics.from_vllm_config(small_vocab_config)
    large_vocab_metrics = UnembedMetrics.from_vllm_config(large_vocab_config)

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    small_flops = small_vocab_metrics.get_num_flops(ctx)
    large_flops = large_vocab_metrics.get_num_flops(ctx)
    assert large_flops == 2 * small_flops
def test_prefill_vs_decode_differences():
    """Attention read bytes should differ between a prefill and a decode step."""
    qwen_config = Qwen3Config(
        hidden_size=2048,
        num_attention_heads=16,
        num_key_value_heads=16,
        num_hidden_layers=1,
    )
    attention_metrics = AttentionMetrics.from_vllm_config(
        create_mock_vllm_config(qwen_config)
    )

    # Prefill: all 512 tokens in one pass.
    prefill_ctx = ExecutionContext.from_single_request(
        num_tokens=512, context_len=512, is_prefill=True
    )
    # Decode: a single token against a 512-token context.
    decode_ctx = ExecutionContext.from_single_request(
        num_tokens=1, context_len=512, is_prefill=False
    )

    prefill_read = attention_metrics.get_read_bytes(prefill_ctx)
    decode_read = attention_metrics.get_read_bytes(decode_ctx)

    assert prefill_read != decode_read
def test_model_metrics_aggregation():
    """The total FLOPs of ModelMetrics should equal the sum of its breakdown."""
    qwen_config = Qwen3Config(
        hidden_size=2048,
        num_attention_heads=16,
        num_hidden_layers=12,
        vocab_size=32000,
        intermediate_size=8192,
    )
    model_metrics = ModelMetrics(create_mock_vllm_config(qwen_config))

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    total_flops = model_metrics.get_num_flops(ctx)
    per_component_flops = model_metrics.get_num_flops_breakdown(ctx)

    # The per-component entries must add up to the reported total.
    assert total_flops == sum(per_component_flops.values())
def test_moe_expert_activation_proportional_scaling():
    """Each extra routed expert per token should add the same FFN cost.

    Three otherwise identical MoE configs with 1, 2 and 3 experts per token
    are compared: going from 1 to 3 experts must add exactly twice what going
    from 1 to 2 adds, for FLOPs, read bytes and write bytes alike.
    """
    # Base (1 expert/token), double (2) and triple (3); all else is shared.
    hf_configs = [
        Qwen3MoeConfig(
            hidden_size=2048,
            intermediate_size=8192,
            num_hidden_layers=12,
            num_experts=64,
            num_experts_per_tok=experts_per_tok,
            moe_intermediate_size=8192,
            n_shared_experts=2,
        )
        for experts_per_tok in (1, 2, 3)
    ]
    vllm_configs = [create_mock_vllm_config(hf_config) for hf_config in hf_configs]
    all_metrics = [FfnMetrics.from_vllm_config(config) for config in vllm_configs]

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    for getter in ("get_num_flops", "get_read_bytes", "get_write_bytes"):
        base_value, double_value, triple_value = (
            getattr(metrics, getter)(ctx) for metrics in all_metrics
        )
        # Cost of one additional expert, and of two additional experts.
        one_expert_diff = double_value - base_value
        two_expert_diff = triple_value - base_value
        # Proportional scaling: two extra experts cost twice one extra expert.
        assert two_expert_diff == 2 * one_expert_diff
def test_quantization_config_parser_fp8():
    """An fp8 quant config should yield 1-byte weights for attention and FFN."""

    class MockQuantConfig:
        # Minimal quant config: only reports its name.
        def get_name(self):
            return "fp8"

    qwen_config = Qwen3Config(
        hidden_size=2048, num_attention_heads=16, num_hidden_layers=1
    )
    mock_config = create_mock_vllm_config(
        qwen_config, quant_config=MockQuantConfig()
    )

    attention_parsed = AttentionMetrics.get_parser().parse(mock_config)
    assert attention_parsed.weight_byte_size == 1  # fp8

    ffn_parsed = FfnMetrics.get_parser().parse(mock_config)
    assert ffn_parsed.weight_byte_size == 1  # fp8
def test_quantization_config_parser_mxfp4():
    """An mxfp4 quant config should yield half-byte FFN weights."""

    class MockQuantConfig:
        # Minimal quant config: only reports its name.
        def get_name(self):
            return "mxfp4"

    qwen_config = Qwen3Config(
        hidden_size=2048, intermediate_size=8192, num_hidden_layers=1
    )
    mock_config = create_mock_vllm_config(
        qwen_config, quant_config=MockQuantConfig()
    )

    ffn_parsed = FfnMetrics.get_parser().parse(mock_config)
    assert ffn_parsed.weight_byte_size == 0.5  # mxfp4
#### Per-GPU Tests ####
def test_attention_per_gpu_with_tensor_parallelism():
    """Compare global and per-GPU attention metrics under TP=4."""
    qwen_config = Qwen3Config(
        hidden_size=4096,
        num_attention_heads=32,
        num_key_value_heads=8,
        num_hidden_layers=24,
    )
    tp4_config = create_mock_vllm_config(qwen_config, tensor_parallel_size=4)
    attention_metrics = AttentionMetrics.from_vllm_config(tp4_config)

    ctx = ExecutionContext.from_single_request(
        num_tokens=128, context_len=1024, is_prefill=True
    )

    # With TP=4 the heads are split 4 ways, so global FLOPs are 4x per-GPU.
    global_flops = attention_metrics.get_num_flops(ctx, per_gpu=False)
    per_gpu_flops = attention_metrics.get_num_flops(ctx, per_gpu=True)
    assert global_flops == 4 * per_gpu_flops

    # Read and write bytes are only required to be larger globally.
    for getter in ("get_read_bytes", "get_write_bytes"):
        global_bytes = getattr(attention_metrics, getter)(ctx, per_gpu=False)
        per_gpu_bytes = getattr(attention_metrics, getter)(ctx, per_gpu=True)
        assert global_bytes > per_gpu_bytes
def test_attention_per_gpu_with_pipeline_parallelism():
    """Under PP=4, global attention FLOPs and read bytes should be 4x per-GPU."""
    qwen_config = Qwen3Config(
        hidden_size=2048,
        num_attention_heads=16,
        num_hidden_layers=32,
    )
    pp4_config = create_mock_vllm_config(qwen_config, pipeline_parallel_size=4)
    attention_metrics = AttentionMetrics.from_vllm_config(pp4_config)

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=False
    )

    # With PP=4 the layers are split 4 ways, so both quantities scale by 4.
    for getter in ("get_num_flops", "get_read_bytes"):
        global_value = getattr(attention_metrics, getter)(ctx, per_gpu=False)
        per_gpu_value = getattr(attention_metrics, getter)(ctx, per_gpu=True)
        assert global_value == 4 * per_gpu_value
def test_ffn_per_gpu_with_tensor_parallelism():
    """Under DP=2 and TP=4, global FFN FLOPs should be 8x the per-GPU FLOPs."""
    qwen_config = Qwen3Config(
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
    )
    # DP=2 combined with TP=4, expert parallelism left disabled.
    dp2_tp4_config = create_mock_vllm_config(
        qwen_config,
        data_parallel_size=2,
        tensor_parallel_size=4,
    )
    ffn_metrics = FfnMetrics.from_vllm_config(dp2_tp4_config)

    # Without EP, ffn_tp_size is expected to be dp_size * tp_size = 8.
    assert ffn_metrics.ffn_tp_size == 8

    ctx = ExecutionContext.from_single_request(
        num_tokens=128, context_len=2048, is_prefill=True
    )

    global_flops = ffn_metrics.get_num_flops(ctx, per_gpu=False)
    per_gpu_flops = ffn_metrics.get_num_flops(ctx, per_gpu=True)

    # An 8-way split means the global figure is 8x the per-GPU one.
    assert global_flops == 8 * per_gpu_flops
def test_ffn_per_gpu_with_pipeline_parallelism():
    """Under PP=6, global FFN FLOPs should be 6x the per-GPU FLOPs."""
    qwen_config = Qwen3Config(
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
    )
    pp6_config = create_mock_vllm_config(qwen_config, pipeline_parallel_size=6)
    ffn_metrics = FfnMetrics.from_vllm_config(pp6_config)

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    global_flops = ffn_metrics.get_num_flops(ctx, per_gpu=False)
    per_gpu_flops = ffn_metrics.get_num_flops(ctx, per_gpu=True)

    # The layers are split across 6 stages, hence the factor of 6.
    assert global_flops == 6 * per_gpu_flops
def test_moe_per_gpu_with_expert_parallelism():
    """
    Compare per-GPU and global MoE read bytes with expert parallelism enabled
    (covers the num_activated_experts bug fix).
    """
    moe_config = Qwen3MoeConfig(
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
        num_experts=64,
        num_experts_per_tok=8,
        moe_intermediate_size=14336,
        n_shared_experts=2,
    )
    # DP=2 and TP=4 with expert parallelism switched on.
    ep_config = create_mock_vllm_config(
        moe_config,
        data_parallel_size=2,
        tensor_parallel_size=4,
        enable_expert_parallel=True,
    )
    ffn_metrics = FfnMetrics.from_vllm_config(ep_config)

    # With EP on, the 8 ranks (dp_size * tp_size) form the EP group
    # and no FFN tensor parallelism remains.
    assert ffn_metrics.ffn_ep_size == 8
    assert ffn_metrics.ffn_tp_size == 1

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    per_gpu_read_breakdown = ffn_metrics.get_read_bytes_breakdown(ctx, per_gpu=True)
    global_read_breakdown = ffn_metrics.get_read_bytes_breakdown(ctx, per_gpu=False)

    # Expectation: per GPU there are 64/8 = 8 experts and 8/8 = 1 routed
    # expert per token, so 100 tokens give 100 activations and
    # num_activated_experts is capped at min(100, 8) = 8. Globally all 64
    # experts are visible, so more routed weights should be read.
    weights_key = "routed_up_gate_weights"
    # NOTE(review): if this key is missing the test passes without checking
    # anything -- confirm whether that is intended.
    if weights_key not in per_gpu_read_breakdown:
        return

    per_gpu_weight_reads = per_gpu_read_breakdown[weights_key]
    global_weight_reads = global_read_breakdown[weights_key]

    # Per-GPU reads must be strictly smaller than the global reads.
    assert per_gpu_weight_reads < global_weight_reads

    # The exact ratio depends on the num_activated_experts calculation;
    # it only has to exceed 1.
    ratio = global_weight_reads / per_gpu_weight_reads
    assert ratio > 1
def test_moe_per_gpu_expert_activation_accounting():
    """Compare per-GPU MoE read bytes for a small and a large prefill batch.

    With 64 experts, top-8 routing and EP=8, each GPU hosts 8 experts and
    sees one routed expert per token. Both T=10 and T=1000 are expected to
    touch all 8 local experts, so weight reads should match while the
    input-activation reads scale with the token count (100x).
    """
    ffn_metrics = FfnMetrics.from_vllm_config(
        create_mock_vllm_config(
            Qwen3MoeConfig(
                hidden_size=2048,
                intermediate_size=8192,
                num_hidden_layers=12,
                num_experts=64,
                num_experts_per_tok=8,
                moe_intermediate_size=14336,
                n_shared_experts=0,  # routed experts only
            ),
            data_parallel_size=8,
            enable_expert_parallel=True,
        )
    )

    def per_gpu_reads(num_tokens):
        # Per-GPU read-byte breakdown for one prefill of ``num_tokens``.
        ctx = ExecutionContext.from_single_request(
            num_tokens=num_tokens, context_len=512, is_prefill=True
        )
        return ffn_metrics.get_read_bytes_breakdown(ctx, per_gpu=True)

    reads_small = per_gpu_reads(10)  # min(10 * 1, 8) = 8 experts expected
    reads_large = per_gpu_reads(1000)  # min(1000 * 1, 8) = 8 experts expected

    if "routed_up_gate_weights" not in reads_small:
        # No routed weight entry in the breakdown: nothing further to check.
        return

    weight_bytes_small = reads_small["routed_up_gate_weights"]
    weight_bytes_large = reads_large["routed_up_gate_weights"]
    # Same set of local experts read in both cases.
    assert weight_bytes_small == weight_bytes_large

    input_bytes_small = reads_small["routed_up_gate_input"]
    input_bytes_large = reads_large["routed_up_gate_input"]
    # Activation traffic follows T * E: 1000 / 10 = 100x.
    assert input_bytes_large == 100 * input_bytes_small
def test_unembed_per_gpu_with_tensor_parallelism():
    """Compare global and per-GPU unembed metrics under TP=8.

    FLOPs and weight reads are expected to be 8x larger globally than per
    GPU, while the input reads are expected to be identical in both views.
    """
    tp_size = 8
    unembed_metrics = UnembedMetrics.from_vllm_config(
        create_mock_vllm_config(
            Qwen3Config(hidden_size=4096, vocab_size=128000),
            tensor_parallel_size=tp_size,
        )
    )
    prefill_ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    # FLOPs: the vocabulary is sharded across TP ranks.
    flops_global = unembed_metrics.get_num_flops(prefill_ctx, per_gpu=False)
    flops_per_gpu = unembed_metrics.get_num_flops(prefill_ctx, per_gpu=True)
    assert flops_global == tp_size * flops_per_gpu

    # Read bytes: weights shard with TP, the input activations do not.
    reads_global = unembed_metrics.get_read_bytes_breakdown(
        prefill_ctx, per_gpu=False
    )
    reads_per_gpu = unembed_metrics.get_read_bytes_breakdown(
        prefill_ctx, per_gpu=True
    )
    assert reads_global["input"] == reads_per_gpu["input"]
    assert reads_global["weight"] == tp_size * reads_per_gpu["weight"]
def test_model_metrics_per_gpu_aggregation():
    """Check that ModelMetrics totals agree with their per-component breakdowns.

    Uses TP=2 and PP=2 and verifies, for both the per-GPU and the global
    view, that the FLOP total equals the sum of the breakdown values, and
    that the global total exceeds the per-GPU total.
    """
    dense_config = Qwen3Config(
        hidden_size=2048,
        num_attention_heads=16,
        num_hidden_layers=12,
        vocab_size=32000,
        intermediate_size=8192,
    )
    metrics = ModelMetrics(
        create_mock_vllm_config(
            dense_config,
            tensor_parallel_size=2,
            pipeline_parallel_size=2,
        )
    )
    prefill_ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    breakdown_per_gpu = metrics.get_num_flops_breakdown(prefill_ctx, per_gpu=True)
    breakdown_global = metrics.get_num_flops_breakdown(prefill_ctx, per_gpu=False)
    total_per_gpu = metrics.get_num_flops(prefill_ctx, per_gpu=True)
    total_global = metrics.get_num_flops(prefill_ctx, per_gpu=False)

    # Totals must be exactly the sum of their components.
    assert total_per_gpu == sum(breakdown_per_gpu.values())
    assert total_global == sum(breakdown_global.values())

    # Parallelism spreads the work, so one GPU does strictly less.
    assert total_global > total_per_gpu

    # The precise factor depends on which parallelism applies to which
    # component (somewhere between PP and TP*PP); only ">1" is pinned.
    assert total_global / total_per_gpu > 1
def test_attention_per_gpu_heads_not_evenly_divisible():
    """Exercise attention metrics when head counts do not divide by TP.

    17 query heads and 5 KV heads are used with TP=4; the test only requires
    that both FLOP figures are positive and that the global figure is larger.
    """
    uneven_config = Qwen3Config(
        hidden_size=2048,
        num_attention_heads=17,  # 17 % 4 != 0
        num_key_value_heads=5,  # 5 % 4 != 0
        num_hidden_layers=8,
    )
    attn_metrics = AttentionMetrics.from_vllm_config(
        create_mock_vllm_config(uneven_config, tensor_parallel_size=4)
    )
    prefill_ctx = ExecutionContext.from_single_request(
        num_tokens=64, context_len=256, is_prefill=True
    )

    # Neither call may raise for the uneven split.
    flops_per_gpu = attn_metrics.get_num_flops(prefill_ctx, per_gpu=True)
    flops_global = attn_metrics.get_num_flops(prefill_ctx, per_gpu=False)

    assert flops_per_gpu > 0
    assert flops_global > 0
    assert flops_global > flops_per_gpu
......@@ -516,6 +516,424 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
del llm
class TestCorrectDecodedToken:
    """Unit tests for ``LogprobsProcessor._correct_decoded_token``.

    The method under test deals with UTF-8 decoding artefacts: when a
    multi-byte character is split across tokens (byte-fallback tokenization),
    decoding a single token yields the Unicode replacement character "�"
    (U+FFFD). The tests drive the method with a mocked tokenizer and check
    which text it returns for a given token index.

    Curly quotes are written as ``\\u201c`` / ``\\u201d`` escapes so the
    multi-byte characters the scenarios are about cannot be silently
    normalized to ASCII quotes.
    """

    @staticmethod
    def _make_processor(tokenizer, logprobs):
        """Build a ``LogprobsProcessor`` around ``tokenizer``.

        Args:
            tokenizer: Tokenizer (a mock in these tests) whose ``decode`` is
                consulted by the method under test.
            logprobs: Already-recorded sample logprobs; the last entry's key
                acts as the "previous token" in the scenarios below.
        """
        # Imported lazily so collecting this module does not import the engine.
        from vllm.v1.engine.logprobs import LogprobsProcessor

        return LogprobsProcessor(
            tokenizer=tokenizer,
            logprobs=logprobs,
            prompt_logprobs=None,
            cumulative_logprob=0.0,
            num_logprobs=1,
            num_prompt_logprobs=None,
        )

    @pytest.fixture
    def mock_tokenizer(self):
        """Create a mock tokenizer for testing."""
        from unittest.mock import Mock

        return Mock()

    @pytest.fixture
    def processor_with_empty_logprobs(self, mock_tokenizer):
        """Create a LogprobsProcessor with no previously recorded logprobs."""
        return self._make_processor(mock_tokenizer, [])

    @pytest.fixture
    def processor_with_previous_logprobs(self, mock_tokenizer):
        """Create a LogprobsProcessor whose previous token ID is 123."""
        return self._make_processor(mock_tokenizer, [{123: None}])

    def test_correction_with_previous_token_in_list(
        self, processor_with_empty_logprobs
    ):
        """Correction succeeds using the preceding token in the same list.

        Scenario: the token at idx=2 decodes to "�" on its own, but decoding
        it together with the token at idx=1 yields valid text.
        """
        processor = processor_with_empty_logprobs
        tokens = [100, 101, 102]  # token IDs

        # decode([101, 102]) -> "valid"; every other call -> "�".
        processor.tokenizer.decode.side_effect = lambda ids: (
            "valid" if ids == [101, 102] else "�"
        )

        result = processor._correct_decoded_token(2, tokens)

        assert result == "valid"
        processor.tokenizer.decode.assert_called_with([101, 102])

    def test_correction_with_previous_logprob_token(
        self, processor_with_previous_logprobs
    ):
        """Correction succeeds using the previous logprob token.

        Scenario: idx=0 has no preceding token in the list, but the token
        recorded in the previous logprobs (123) completes the sequence.
        """
        processor = processor_with_previous_logprobs
        tokens = [100]  # single token

        def mock_decode(ids):
            # decode([123, 100]) -> ' "polarized"'; anything else -> "�".
            if ids == [123, 100]:
                return ' "polarized"'
            return "�"

        processor.tokenizer.decode.side_effect = mock_decode

        result = processor._correct_decoded_token(0, tokens)

        assert result == ' "polarized"'

    def test_correction_at_idx_zero_no_previous_logprobs(
        self, processor_with_empty_logprobs
    ):
        """idx=0 with no previous logprobs falls back to the empty string."""
        processor = processor_with_empty_logprobs
        tokens = [100]

        # The tokenizer can never produce anything but "�".
        processor.tokenizer.decode.return_value = "�"

        result = processor._correct_decoded_token(0, tokens)

        assert result == ""

    def test_correction_at_idx_zero_with_previous_logprobs(
        self, processor_with_previous_logprobs
    ):
        """idx=0 is corrected via the previous logprob token when one exists."""
        processor = processor_with_previous_logprobs
        tokens = [200]

        def mock_decode(ids):
            if ids == [123, 200]:
                return "corrected"
            return "�"

        processor.tokenizer.decode.side_effect = mock_decode

        result = processor._correct_decoded_token(0, tokens)

        assert result == "corrected"

    def test_no_correction_needed_returns_fallback(
        self, processor_with_previous_logprobs
    ):
        """Empty string is returned when no correction attempt works.

        Scenario: every decode attempt still ends with "�".
        """
        processor = processor_with_previous_logprobs
        tokens = [100, 101, 102]

        # Every decode result ends with the replacement character.
        processor.tokenizer.decode.return_value = "still�"

        result = processor._correct_decoded_token(2, tokens)

        assert result == ""

    def test_middle_token_correction(self, processor_with_previous_logprobs):
        """A token in the middle of a longer list (idx=5) is corrected."""
        processor = processor_with_previous_logprobs
        tokens = [10, 20, 30, 40, 50, 60, 70, 80]

        def mock_decode(ids):
            # Only the pair (tokens[4], tokens[5]) decodes cleanly.
            if ids == [50, 60]:
                return "olar"
            return "�"

        processor.tokenizer.decode.side_effect = mock_decode

        result = processor._correct_decoded_token(5, tokens)

        assert result == "olar"

    def test_multiple_consecutive_replacement_chars(
        self, processor_with_previous_logprobs
    ):
        """Consecutive uncorrectable tokens each become the empty string.

        Scenario: a sequence like ["�", "�", "p"], where the first two
        entries cannot be repaired.
        """
        processor = processor_with_previous_logprobs
        tokens = [100, 101, 102]
        processor.tokenizer.decode.return_value = "still�"

        # First replacement char.
        result1 = processor._correct_decoded_token(0, tokens)
        assert result1 == ""

        # Second replacement char.
        result2 = processor._correct_decoded_token(1, tokens)
        assert result2 == ""

    def test_correction_with_multibyte_utf8(self, processor_with_previous_logprobs):
        """Correction involving multi-byte UTF-8 characters (curly quotes).

        Scenario: byte-fallback tokenization splits multi-byte characters,
        so the raw per-token output is ["�", "�"] while the expected
        corrected pieces are a space plus left curly quote for idx=0 and a
        right curly quote for idx=1.
        """
        processor = processor_with_previous_logprobs
        tokens = [200, 201]

        def mock_decode(ids):
            # idx=0 decoded together with the previous logprob token.
            if ids == [123, 200]:
                return " \u201c"  # Space + left curly quote
            # idx=1 decoded together with the previous token in the list.
            if ids == [200, 201]:
                return "\u201d"  # Right curly quote
            # idx=1 decoded with previous logprob token + previous token.
            if ids == [123, 200, 201]:
                return " \u201c\u201d"  # Full sequence
            return "�"

        processor.tokenizer.decode.side_effect = mock_decode

        # idx=0: expected to use decode([123, 200]).
        result1 = processor._correct_decoded_token(0, tokens)
        assert result1 == " \u201c"

        # idx=1: expected to use decode([200, 201]).
        result2 = processor._correct_decoded_token(1, tokens)
        assert result2 == "\u201d"

    def test_real_world_opt125m_scenario(self, mock_tokenizer):
        """Replay the reported facebook/opt-125m sequence with a mock.

        Before: [" the", " term", " �", "�", "p", "olar", "ized", "�", "�", ...]
        After:  [" the", " term", "", " \u201c", "p", "olar", "ized", "",
                 "\u201d", ...]
        """
        processor = self._make_processor(mock_tokenizer, [])

        # Placeholder IDs standing in for the problematic sequence.
        tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9]

        # Decode results keyed by the exact ID sequence; anything not listed
        # decodes to ordinary text.
        decoded_by_ids = {
            # Single tokens that are incomplete byte sequences.
            (3,): "�",
            (4,): "�",
            (8,): "�",
            (9,): "�",
            # Pairs.
            (2, 3): " term�",  # Still ends with �, needs more context
            (3, 4): " \u201c",  # Corrected to space + left curly quote
            (7, 8): "ized�",  # Still ends with �
            (8, 9): "\u201d",  # Corrected to right curly quote
            # Triples.
            (1, 2, 3): " the term�",  # Still ends with �
            (2, 3, 4): " term \u201c",  # With all context
        }

        def mock_decode(ids):
            return decoded_by_ids.get(tuple(ids), "normal_text")

        mock_tokenizer.decode.side_effect = mock_decode

        # idx=2 (token 3): decode([2, 3]) still ends with "�" and there are
        # no previous logprobs, so the result falls back to "".
        result = processor._correct_decoded_token(2, tokens)
        assert result == ""

        # idx=3 (token 4): decode([3, 4]) yields the corrected text.
        processor.logprobs = [{2: None}]  # Add previous logprob
        result = processor._correct_decoded_token(3, tokens)
        assert result == " \u201c"
def test_verify_tokens_integration():
    """Integration test for ``_verify_tokens`` with a real model.

    Generates with facebook/opt-125m (known to emit split multi-byte
    sequences) and checks that no returned decoded token ends with, or
    consists solely of, the replacement character "�".
    """
    # Prompt taken from the reported example that triggers the issue.
    test_prompts = ["In this example,"]
    sampling_params = SamplingParams(
        max_tokens=16,
        temperature=0,
        logprobs=0,
    )

    # Run inside a ``with`` block so the engine is released even when an
    # assertion below fails; otherwise the model stays resident for the
    # following tests.
    # NOTE(review): relies on VllmRunner implementing __enter__/__exit__.
    with VllmRunner(
        "facebook/opt-125m",
        max_logprobs=0,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    ) as runner:
        results = runner.llm.generate(test_prompts, sampling_params=sampling_params)

    # Verify that decoded tokens don't contain replacement characters.
    for result in results:
        assert result.outputs[0].logprobs is not None
        for logprob_dict in result.outputs[0].logprobs:
            for token_id, logprob_info in logprob_dict.items():
                decoded_token = logprob_info.decoded_token
                # Tokens should either be corrected or be the empty string.
                assert not decoded_token.endswith("�"), (
                    f"Token {token_id} decoded to '{decoded_token}' which "
                    f"ends with replacement character"
                )
                # A lone replacement character is never acceptable.
                assert decoded_token != "�", (
                    f"Token {token_id} is a lone replacement character"
                )
def test_utf8_edge_cases_with_real_model():
    """Test various UTF-8 edge cases with a real model.

    Feeds prompts containing multi-byte characters (curly quotes, dashes,
    ellipsis, CJK, emoji) to facebook/opt-125m and checks that no decoded
    logprob token ends with the replacement character "�".
    """
    # Curly quotes are spelled as escapes so they stay multi-byte characters
    # rather than being normalized to ASCII quotes.
    test_prompts = [
        "Smart quotes: \u201cHello\u201d",  # Curly quotes
        "Em dash — test",  # Em dash
        "Ellipsis… continues",  # Ellipsis
        "Chinese: 你好",  # Chinese characters
        "Emoji: 😀 🎉",  # Emojis
        "Mixed: \u201cquoted\u201d — with symbols",  # Mixed
    ]
    sampling_params = SamplingParams(
        max_tokens=10,
        temperature=0,
        logprobs=1,
    )

    # ``with`` guarantees the engine is torn down before the next test.
    # NOTE(review): relies on VllmRunner implementing __enter__/__exit__.
    with VllmRunner(
        "facebook/opt-125m",
        max_logprobs=1,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    ) as runner:
        results = runner.llm.generate(test_prompts, sampling_params=sampling_params)

    # strict=True turns a prompt/result count mismatch into an error instead
    # of silently skipping prompts.
    for prompt, result in zip(test_prompts, results, strict=True):
        assert result.outputs[0].logprobs is not None
        # No decoded token may end with the replacement character.
        for logprob_dict in result.outputs[0].logprobs:
            for token_id, logprob_info in logprob_dict.items():
                decoded_token = logprob_info.decoded_token
                assert not decoded_token.endswith("�"), (
                    f"Prompt: '{prompt}'\n"
                    f"Token {token_id} decoded to '{decoded_token}' which "
                    f"ends with replacement character"
                )
def test_correct_decoded_token_preserves_valid_tokens():
    """Test that plain ASCII generations yield clean decoded tokens.

    ``_correct_decoded_token`` is only meant to act on tokens ending with
    "�"; this test runs an ASCII-only prompt through facebook/opt-125m and
    checks that every decoded logprob token is a ``str`` free of the
    replacement character.
    """
    # Simple prompt with standard ASCII characters.
    test_prompts = ["Hello world, this is a test."]
    sampling_params = SamplingParams(
        max_tokens=10,
        temperature=0,
        logprobs=2,
    )

    # ``with`` guarantees the engine is torn down before the next test.
    # NOTE(review): relies on VllmRunner implementing __enter__/__exit__.
    with VllmRunner(
        "facebook/opt-125m",
        max_logprobs=2,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    ) as runner:
        results = runner.llm.generate(test_prompts, sampling_params=sampling_params)

    for result in results:
        assert result.outputs[0].logprobs is not None
        for logprob_dict in result.outputs[0].logprobs:
            for logprob_info in logprob_dict.values():
                decoded_token = logprob_info.decoded_token
                # Tokens are strings (possibly empty if they were corrected).
                assert isinstance(decoded_token, str)
                # And never carry a replacement character.
                assert "�" not in decoded_token
@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
@pytest.mark.parametrize(
"model_setup",
......@@ -524,32 +942,74 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
(
"eagle",
"meta-llama/Llama-3.2-1B-Instruct",
"nm-testing/Llama3_2_1B_speculator.eagle3",
{
"method": "eagle",
"model": "nm-testing/Llama3_2_1B_speculator.eagle3",
"num_speculative_tokens": 3,
},
0,
),
marks=large_gpu_mark(min_gb=32),
id="eagle0",
),
pytest.param(
(
"eagle",
"meta-llama/Llama-3.2-1B-Instruct",
{
"method": "eagle",
"model": "nm-testing/Llama3_2_1B_speculator.eagle3",
"num_speculative_tokens": 3,
},
3,
),
marks=large_gpu_mark(min_gb=32),
id="eagle3",
),
pytest.param(
(
"ngram",
"meta-llama/Llama-3.2-1B-Instruct",
{
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
3,
),
marks=large_gpu_mark(min_gb=32),
id="ngram",
),
],
)
@pytest.mark.parametrize("top_logprobs", [0, 3])
def test_spec_decode_logprobs(
logprobs_mode: LogprobsMode,
model_setup: tuple[str, str, str],
top_logprobs: int,
model_setup: tuple[str, str, dict, int],
):
"""Spec decode logprobs should match those of the base model.
Args:
logprobs_mode: logprobs mode.
model_setup: Spec decode method, base model name, and
draft model name.
model_setup: Tuple of (method, base model name,
speculative_config dict, top_logprobs).
"""
from vllm import LLM
method, model_name, spec_config, top_logprobs = model_setup
prompt = "Hello world " * 50
sampling_params = SamplingParams(
temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
)
method, model_name, spec_model_name = model_setup
penalty_sampling_params = SamplingParams(
temperature=0,
logprobs=top_logprobs,
max_tokens=10,
ignore_eos=False,
presence_penalty=-1.0,
)
max_model_len = 256
# Run base LLM.
......@@ -560,27 +1020,27 @@ def test_spec_decode_logprobs(
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
enable_prefix_caching=False,
)
ref_results = ref_llm.generate(
[prompt, prompt], [sampling_params, penalty_sampling_params]
)
ref_results = ref_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from reference LLM.
ref_logprobs = []
for output in ref_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
ref_logprobs.append(logprobs[token_id])
for results in ref_results:
for output in results.outputs:
for logprobs in output.logprobs:
ref_logprobs.extend(logprobs.values())
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
# Run spec decode LLM.
# Add max_model_len to spec_config if not present
spec_config_with_len = {**spec_config, "max_model_len": max_model_len}
spec_llm = LLM(
model_name,
speculative_config={
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": max_model_len,
},
speculative_config=spec_config_with_len,
max_logprobs=5,
max_model_len=max_model_len,
seed=42,
......@@ -589,14 +1049,17 @@ def test_spec_decode_logprobs(
# Force prefill chunking
enable_chunked_prefill=True,
max_num_batched_tokens=32,
enable_prefix_caching=False,
)
spec_results = spec_llm.generate(
[prompt, prompt], [sampling_params, penalty_sampling_params]
)
spec_results = spec_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM.
spec_logprobs = []
for output in spec_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
spec_logprobs.append(logprobs[token_id])
for results in spec_results:
for output in results.outputs:
for logprobs in output.logprobs:
spec_logprobs.extend(logprobs.values())
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
......
......@@ -691,9 +691,13 @@ def test_frequency_penalties(rejection_sampler):
def test_bad_words(rejection_sampler):
"""Test rejection sampling with bad words constraints"""
"""Test rejection sampling with bad words constraints.
This test applies bad words to non-consecutive requests (0 and 2, but not 1)
to verify correct logit indexing when iterating over requests with bad words.
"""
spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
output_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
output_tokens = [[1, 2, 3, 4], [1, 15, 3, 4], [1, 2, 3, 4]]
logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
metadata = create_sampling_metadata(
......@@ -701,17 +705,9 @@ def test_bad_words(rejection_sampler):
output_token_ids=[[2], [3], [4]],
spec_token_ids=spec_tokens,
bad_words_token_ids={
0: [
[
2,
]
],
1: [
[
2,
]
],
# Do not apply bad words to the last request
0: [[2]],
# Request 1 has no bad words (to test non-consecutive request handling)
2: [[2]],
},
)
bonus_token_tensor = torch.tensor(
......@@ -726,8 +722,11 @@ def test_bad_words(rejection_sampler):
sampling_metadata=metadata,
)
# Request 0: bad word [2] matches prefix, so token 2 is rejected -> 15
# Request 1: no bad words, all tokens match -> [1, 15, 3, 4]
# Request 2: bad word [2] matches prefix, so token 2 is rejected -> 15
expected = torch.tensor(
[[1, 15, -1, -1], [1, 15, 3, 4], [1, 2, 3, 4]],
[[1, 15, -1, -1], [1, 15, 3, 4], [1, 15, -1, -1]],
dtype=torch.int,
device=logits.device,
)
......
......@@ -14,8 +14,8 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
try_get_attention_backend,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
AttentionConfig,
CacheConfig,
DeviceConfig,
ModelConfig,
......@@ -27,6 +27,7 @@ from vllm.config import (
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
......@@ -41,6 +42,7 @@ eagle3_dir = os.path.join(models_path_prefix, "yuhuili/EAGLE3-LLaMA3.1-Instruct-
def _create_proposer(
method: str,
num_speculative_tokens: int,
attention_backend: str | None = None,
speculative_token_tree: list[tuple[int, ...]] | None = None,
) -> EagleProposer:
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
......@@ -73,6 +75,7 @@ def _create_proposer(
max_model_len=model_config.max_model_len,
is_encoder_decoder=model_config.is_encoder_decoder,
),
attention_config=AttentionConfig(backend=attention_backend),
)
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
......@@ -306,10 +309,16 @@ def test_prepare_inputs_padded():
proposer = _create_proposer("eagle", num_speculative_tokens)
output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
output_metadata, token_indices_to_sample, num_rejected_tokens_gpu = (
proposer.prepare_inputs_padded(
common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
)
)
# Verify num_rejected_tokens_gpu is calculated correctly
expected_num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32, device=device)
assert torch.equal(num_rejected_tokens_gpu, expected_num_rejected)
assert output_metadata.max_query_len == 3
assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
......@@ -334,8 +343,6 @@ def test_load_model(
use_distinct_lm_head,
monkeypatch,
):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
......@@ -399,7 +406,9 @@ def test_load_model(
assert not isinstance(target_model, SupportsMultiModal)
# Create proposer using the helper function
proposer = _create_proposer(method, num_speculative_tokens=8)
proposer = _create_proposer(
method, num_speculative_tokens=8, attention_backend=attn_backend
)
# Call the method under test
proposer.load_model(target_model)
......@@ -425,8 +434,6 @@ def test_load_model(
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
......@@ -454,7 +461,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
seq_lens = [seq_len_1, seq_len_2]
# Create proposer first so we can use its actual hidden_size
proposer = _create_proposer("eagle", num_speculative_tokens)
proposer = _create_proposer(
"eagle", num_speculative_tokens, attention_backend=attn_backend
)
# Get the hidden_size from the proposer to ensure consistency
hidden_size = proposer.hidden_size
......@@ -627,7 +636,9 @@ def test_propose_tree(spec_token_tree):
# Create proposer first so we can use its actual hidden_size.
proposer = _create_proposer(
"eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree
"eagle",
num_speculative_tokens,
speculative_token_tree=spec_token_tree,
)
# Get the hidden_size from the proposer to ensure consistency.
hidden_size = proposer.hidden_size
......
......@@ -13,7 +13,6 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
try_get_attention_backend,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
CacheConfig,
DeviceConfig,
......@@ -26,6 +25,7 @@ from vllm.config import (
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.eagle import EagleProposer
from ...utils import models_path_prefix
......
......@@ -85,10 +85,8 @@ def test_ngram_proposer():
token_ids_cpu = np.array([[1, 2, 3, 4, 5]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 0
......@@ -96,10 +94,8 @@ def test_ngram_proposer():
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
result = get_ngram_proposer(min_n=4, max_n=4, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 0
......@@ -107,10 +103,8 @@ def test_ngram_proposer():
token_ids_cpu = np.array([[1, 2, 3, 4, 1, 2, 3]])
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[4, 1]]))
......@@ -119,10 +113,8 @@ def test_ngram_proposer():
token_ids_cpu = np.array([[2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]])
result = get_ngram_proposer(min_n=3, max_n=4, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 1]]
......@@ -130,10 +122,8 @@ def test_ngram_proposer():
token_ids_cpu = np.array([[3, 4, 5, 2, 3, 4, 1, 2, 3, 4]])
result = get_ngram_proposer(min_n=2, max_n=4, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 2]]
......@@ -141,10 +131,8 @@ def test_ngram_proposer():
token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]])
result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert np.array_equal(result, np.array([[100, 1]]))
......@@ -152,10 +140,8 @@ def test_ngram_proposer():
token_ids_cpu = np.array([[]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
sampled_token_ids=[[0]],
req_ids=["0"],
num_tokens_no_spec=np.array([len(c) for c in token_ids_cpu]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 0
......@@ -165,10 +151,8 @@ def test_ngram_proposer():
token_ids_cpu = np.array([[1, 2, 3, 1, 2], [4, 5, 6, -1, -1]])
result = get_ngram_proposer(min_n=2, max_n=2, k=2).propose(
sampled_token_ids=[[0], [1]],
req_ids=["0", "1"],
num_tokens_no_spec=np.array([5, 3]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 2
assert np.array_equal(result[0], np.array([3, 1]))
......@@ -186,10 +170,8 @@ def test_ngram_proposer():
sampled_token_ids = [[2], [], [8]] # Empty list for request 1 simulates prefill
result = proposer.propose(
sampled_token_ids=sampled_token_ids,
req_ids=["0", "1", "2"],
num_tokens_no_spec=num_tokens_no_spec,
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result) == 3
assert np.array_equal(result[0], [3, 1])
......@@ -217,10 +199,8 @@ def test_ngram_proposer():
token_ids_cpu = np.array([input_1, input_2])
result = ngram_proposer.propose(
sampled_token_ids=[[0], [1]],
req_ids=["0", "1"],
num_tokens_no_spec=np.array([len(input_1), 3]),
token_ids_cpu=token_ids_cpu,
spec_decode_unsupported_reqs=(),
)
assert len(result[0]) == 2
assert np.array_equal(result[0], np.array([middle_integer + 2, middle_integer + 3]))
......
......@@ -11,10 +11,10 @@ from tests.v1.attention.utils import (
create_vllm_config,
try_get_attention_backend,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum
if not is_flash_attn_varlen_func_available():
pytest.skip(
......
......@@ -38,53 +38,48 @@ def test_ngram_max_len(num_speculative_tokens: int):
def test_eagle_max_len(
monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str
):
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "eagle",
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": num_speculative_tokens,
"max_model_len": 80,
},
max_model_len=200,
llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "eagle",
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": num_speculative_tokens,
"max_model_len": 80,
},
max_model_len=200,
attention_config={"backend": attn_backend},
)
sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
outputs = llm.generate(_PROMPTS, sampling_params)
for o in outputs:
assert o.outputs[0].finish_reason == "length", (
"This test is only meaningful if the output is truncated due to max length"
)
sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
outputs = llm.generate(_PROMPTS, sampling_params)
for o in outputs:
assert o.outputs[0].finish_reason == "length", (
"This test is only meaningful if the output "
"is truncated due to max length"
)
sampling_params = SamplingParams(
max_tokens=200,
structured_outputs=StructuredOutputsParams(
regex="^" + "a b c d e " * 15 + "$"
),
sampling_params = SamplingParams(
max_tokens=200,
structured_outputs=StructuredOutputsParams(regex="^" + "a b c d e " * 15 + "$"),
)
output = llm.generate(_PROMPTS, sampling_params)
for o in output:
assert o.prompt_token_ids is not None
assert (
len(o.prompt_token_ids)
< 80
< len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
<= 200
), (
"This test is only meaningful if the output "
"is longer than the eagle max length"
)
output = llm.generate(_PROMPTS, sampling_params)
for o in output:
assert o.prompt_token_ids is not None
assert (
len(o.prompt_token_ids)
< 80
< len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
<= 200
), (
"This test is only meaningful if the output "
"is longer than the eagle max length"
)
assert o.outputs[0].text == "a b c d e " * 15
assert o.outputs[0].text == "a b c d e " * 15
......@@ -71,6 +71,7 @@ class TestReasoningStructuredOutput:
request.prompt_token_ids = [1, 2, 3, 4, 5]
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
request.num_computed_tokens = 5
request.num_output_placeholders = 0
return request
def test_should_fill_bitmask_with_enable_in_reasoning(
......
......@@ -6,8 +6,6 @@ import numpy as np
import pytest
import torch
from vllm.attention.backends.abstract import MultipleOf
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention
from vllm.config import (
AttentionConfig,
......@@ -27,6 +25,9 @@ from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backend import MultipleOf
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.kv_cache_interface import (
......@@ -113,15 +114,16 @@ def get_vllm_config():
@pytest.fixture
def model_runner():
vllm_config = get_vllm_config()
model_config = vllm_config.model_config
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
head_size = model_config.get_head_size()
vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
num_heads, head_size, 0.1
)
runner = GPUModelRunner(vllm_config, DEVICE)
initialize_kv_cache(runner)
return runner
with set_current_vllm_config(vllm_config):
model_config = vllm_config.model_config
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
head_size = model_config.get_head_size()
vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
num_heads, head_size, 0.1
)
runner = GPUModelRunner(vllm_config, DEVICE)
initialize_kv_cache(runner)
yield runner
model_runner_2 = model_runner
......@@ -547,7 +549,7 @@ def test_reload_weights_before_load_model(model_runner):
model_runner.reload_weights()
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
......@@ -574,7 +576,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
......@@ -601,7 +603,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
def test_init_kv_cache_with_kv_sharing_target_same_as_current(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
......@@ -628,7 +630,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
assert fwd_context is not None
def test_init_kv_cache_without_kv_sharing():
def test_init_kv_cache_without_kv_sharing(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
......@@ -695,7 +697,7 @@ def test_init_kv_cache_without_kv_sharing():
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_init_kv_cache_with_kv_sharing_valid():
def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
......@@ -778,7 +780,7 @@ def test_hybrid_attention_mamba_tensor_shapes():
will not corrupt an attention block and vice versa
"""
current_platform.seed_everything(42)
set_random_seed(42)
update_environment_variables(
{
......@@ -1048,7 +1050,7 @@ def test_input_batch_with_kernel_block_sizes():
assert block_table.block_size == kernel_size
def test_hybrid_cache_integration(model_runner, dist_init):
def test_hybrid_cache_integration(default_vllm_config, dist_init):
"""Test hybrid cache architecture integration with GPUModelRunner."""
# Create a new model runner with hybrid cache configuration
vllm_config = get_vllm_config()
......@@ -1112,3 +1114,87 @@ def test_hybrid_cache_integration(model_runner, dist_init):
runner._update_states(scheduler_output)
assert _is_req_scheduled(runner, req_id)
assert _is_req_state_block_table_match(runner, req_id)
def test_is_uniform_decode() -> None:
    """Exercise ``GPUModelRunner._is_uniform_decode`` over a table of cases.

    Each row lists the four sizing arguments, any extra keyword arguments
    (only ``force_uniform_decode`` is used here), and whether the call is
    expected to return a truthy value.
    """
    # Columns: max_num_scheduled_tokens, uniform_decode_query_len,
    #          num_tokens, num_reqs, extra kwargs, expected truthiness.
    scenarios = [
        # Normal
        (1, 1, 16, 16, {}, True),
        (2, 1, 16, 16, {}, False),
        (1, 1, 16, 15, {}, False),
        # Spec decoding
        (5, 5, 30, 6, {}, True),
        (5, 4, 30, 6, {}, False),
        (5, 5, 30, 7, {}, False),
        # Force uniform decode
        (1, 1, 16, 16, {"force_uniform_decode": True}, True),
        (2, 1, 16, 16, {"force_uniform_decode": True}, True),
        (1, 1, 16, 15, {"force_uniform_decode": True}, True),
        (1, 1, 16, 16, {"force_uniform_decode": False}, False),
        (2, 1, 16, 16, {"force_uniform_decode": False}, False),
        (1, 1, 16, 15, {"force_uniform_decode": False}, False),
    ]

    for max_sched, query_len, token_count, req_count, extra, expected in scenarios:
        # Rows without an explicit override omit the keyword entirely so the
        # method's own default for ``force_uniform_decode`` is what gets used.
        outcome = GPUModelRunner._is_uniform_decode(
            max_num_scheduled_tokens=max_sched,
            uniform_decode_query_len=query_len,
            num_tokens=token_count,
            num_reqs=req_count,
            **extra,
        )
        assert bool(outcome) is expected
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment