Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7e63ef82
Commit
7e63ef82
authored
Jan 21, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.14.0' into v0.14.0-dev
parents
8cbcac5d
b17039bc
Changes
681
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2946 additions
and
354 deletions
+2946
-354
tests/v1/kv_connector/unit/test_example_connector.py
tests/v1/kv_connector/unit/test_example_connector.py
+3
-3
tests/v1/kv_connector/unit/test_lmcache_connector.py
tests/v1/kv_connector/unit/test_lmcache_connector.py
+30
-0
tests/v1/kv_connector/unit/test_moriio_connector.py
tests/v1/kv_connector/unit/test_moriio_connector.py
+545
-0
tests/v1/kv_connector/unit/test_multi_connector.py
tests/v1/kv_connector/unit/test_multi_connector.py
+45
-0
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+517
-129
tests/v1/kv_connector/unit/test_offloading_connector.py
tests/v1/kv_connector/unit/test_offloading_connector.py
+194
-75
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+11
-1
tests/v1/kv_offload/test_cpu_gpu.py
tests/v1/kv_offload/test_cpu_gpu.py
+3
-1
tests/v1/kv_offload/test_cpu_offloading.py
tests/v1/kv_offload/test_cpu_offloading.py
+10
-11
tests/v1/kv_offload/test_worker.py
tests/v1/kv_offload/test_worker.py
+12
-0
tests/v1/metrics/test_perf_metrics.py
tests/v1/metrics/test_perf_metrics.py
+907
-0
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+486
-23
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+13
-14
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+21
-10
tests/v1/spec_decode/test_mtp.py
tests/v1/spec_decode/test_mtp.py
+1
-1
tests/v1/spec_decode/test_ngram.py
tests/v1/spec_decode/test_ngram.py
+0
-20
tests/v1/spec_decode/test_tree_attention.py
tests/v1/spec_decode/test_tree_attention.py
+3
-3
tests/v1/spec_decode/untest_max_len.py
tests/v1/spec_decode/untest_max_len.py
+40
-45
tests/v1/structured_output/test_reasoning_structured_output.py
.../v1/structured_output/test_reasoning_structured_output.py
+1
-0
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+104
-18
No files found.
Too many changes to show.
To preserve performance only
681 of 681+
files are displayed.
Plain diff
Email patch
tests/v1/kv_connector/unit/test_example_connector.py
View file @
7e63ef82
...
...
@@ -9,7 +9,7 @@ from PIL import Image
from
vllm
import
LLM
,
EngineArgs
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.config
import
KVTransferConfig
from
vllm.multimodal.utils
import
encode_image_
base64
from
vllm.multimodal.utils
import
encode_image_
url
from
vllm.platforms
import
current_platform
MODEL_NAME
=
"RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"
...
...
@@ -74,7 +74,7 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
placeholders
=
[
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
f
"data:image;base64,
{
encode_image_
base64
(
image_pil
)
}
"
}
,
"image_url"
:
{
"url"
:
encode_image_
url
(
image_pil
)},
}
for
image_pil
in
image_urls
]
...
...
@@ -145,7 +145,7 @@ def test_shared_storage_connector_hashes(tmp_path):
# don't put this import at the top level
# it will call torch.cuda.device_count()
from
transformers
import
AutoProcessor
# noqa: F401
from
transformers
import
AutoProcessor
# Create processor to handle the chat prompt
processor
=
AutoProcessor
.
from_pretrained
(
MODEL_NAME
)
...
...
tests/v1/kv_connector/unit/test_lmcache_connector.py
View file @
7e63ef82
...
...
@@ -25,6 +25,7 @@ def mock_lmcache_engine_event():
lora_id
,
block_size
,
medium
,
lora_name
,
):
self
.
block_hashes
=
block_hashes
self
.
parent_block_hash
=
parent_block_hash
...
...
@@ -32,6 +33,7 @@ def mock_lmcache_engine_event():
self
.
lora_id
=
lora_id
self
.
block_size
=
block_size
self
.
medium
=
medium
self
.
lora_name
=
lora_name
return
MockEvent
(
block_hashes
=
[
"hash1"
,
"hash2"
],
...
...
@@ -40,6 +42,7 @@ def mock_lmcache_engine_event():
lora_id
=
None
,
block_size
=
16
,
medium
=
"GPU"
,
lora_name
=
None
,
)
...
...
@@ -109,6 +112,7 @@ class TestGetKVConnectorKVCacheEvents:
assert
events
[
0
].
lora_id
is
None
assert
events
[
0
].
block_size
==
16
assert
events
[
0
].
medium
==
"GPU"
assert
events
[
0
].
lora_name
is
None
def
test_converts_multiple_events
(
self
,
mock_connector
):
"""Test conversion of multiple events from lmcache engine format."""
...
...
@@ -121,6 +125,7 @@ class TestGetKVConnectorKVCacheEvents:
self
.
lora_id
=
None
self
.
block_size
=
16
self
.
medium
=
"GPU"
self
.
lora_name
=
None
events
=
[
MockEvent
(
i
)
for
i
in
range
(
5
)]
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
events
...
...
@@ -150,6 +155,7 @@ class TestGetKVConnectorKVCacheEvents:
self
.
lora_id
=
42
self
.
block_size
=
32
self
.
medium
=
"DISK"
self
.
lora_name
=
"lora_example"
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
[
MockEventWithLora
()
...
...
@@ -166,6 +172,7 @@ class TestGetKVConnectorKVCacheEvents:
assert
event
.
lora_id
==
42
assert
event
.
block_size
==
32
assert
event
.
medium
==
"DISK"
assert
event
.
lora_name
==
"lora_example"
def
test_handles_none_parent_block_hash
(
self
,
mock_connector
):
"""Test handling of events with None parent_block_hash."""
...
...
@@ -178,6 +185,7 @@ class TestGetKVConnectorKVCacheEvents:
self
.
lora_id
=
None
self
.
block_size
=
16
self
.
medium
=
"GPU"
self
.
lora_name
=
None
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
[
MockEventNoParent
()
...
...
@@ -223,6 +231,7 @@ class TestUpdateConnectorOutput:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
kv_events
.
add_events
([
event
])
...
...
@@ -243,6 +252,7 @@ class TestUpdateConnectorOutput:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
existing_events
.
add_events
([
event1
])
existing_events
.
add_events
([
event1
])
# Simulate 2 workers reporting
...
...
@@ -258,6 +268,7 @@ class TestUpdateConnectorOutput:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
new_events
.
add_events
([
event2
])
...
...
@@ -288,6 +299,7 @@ class TestUpdateConnectorOutput:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
new_events
.
add_events
([
event
])
...
...
@@ -309,6 +321,7 @@ class TestUpdateConnectorOutput:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
events1
.
add_events
([
event1
])
output1
=
KVConnectorOutput
(
kv_cache_events
=
events1
)
...
...
@@ -323,6 +336,7 @@ class TestUpdateConnectorOutput:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
events2
.
add_events
([
event2
])
output2
=
KVConnectorOutput
(
kv_cache_events
=
events2
)
...
...
@@ -337,6 +351,7 @@ class TestUpdateConnectorOutput:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
events3
.
add_events
([
event3
])
output3
=
KVConnectorOutput
(
kv_cache_events
=
events3
)
...
...
@@ -358,6 +373,7 @@ class TestUpdateConnectorOutput:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
events1
.
add_events
([
event1
])
output1
=
KVConnectorOutput
(
kv_cache_events
=
events1
)
...
...
@@ -397,6 +413,7 @@ class TestTakeEvents:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
event2
=
BlockStored
(
block_hashes
=
[
"hash2"
],
...
...
@@ -405,6 +422,7 @@ class TestTakeEvents:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
kv_events
.
add_events
([
event1
,
event2
])
mock_connector
.
_kv_cache_events
=
kv_events
...
...
@@ -431,6 +449,7 @@ class TestTakeEvents:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
uncommon_event
=
BlockStored
(
block_hashes
=
[
"hash_uncommon"
],
...
...
@@ -439,6 +458,7 @@ class TestTakeEvents:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
# All 3 workers report common_event
...
...
@@ -469,6 +489,7 @@ class TestTakeEvents:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
kv_events1
.
add_events
([
event1
])
mock_connector
.
_kv_cache_events
=
kv_events1
...
...
@@ -491,6 +512,7 @@ class TestTakeEvents:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
kv_events2
.
add_events
([
event2
])
mock_connector
.
_kv_cache_events
=
kv_events2
...
...
@@ -510,6 +532,7 @@ class TestTakeEvents:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
event2
=
BlockStored
(
block_hashes
=
[
"hash2"
],
...
...
@@ -518,6 +541,7 @@ class TestTakeEvents:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
# Worker 1 reports event1
...
...
@@ -572,6 +596,7 @@ class TestIntegrationScenarios:
self
.
lora_id
=
None
self
.
block_size
=
16
self
.
medium
=
"GPU"
self
.
lora_name
=
None
# Worker 1
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
[
...
...
@@ -628,6 +653,7 @@ class TestIntegrationScenarios:
self
.
lora_id
=
None
self
.
block_size
=
16
self
.
medium
=
"GPU"
self
.
lora_name
=
None
for
cycle
in
range
(
3
):
# Get events
...
...
@@ -667,6 +693,7 @@ class TestIntegrationScenarios:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
worker1_unique_event
=
BlockStored
(
...
...
@@ -676,6 +703,7 @@ class TestIntegrationScenarios:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
worker2_unique_event
=
BlockStored
(
...
...
@@ -685,6 +713,7 @@ class TestIntegrationScenarios:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
worker3_unique_event
=
BlockStored
(
...
...
@@ -694,6 +723,7 @@ class TestIntegrationScenarios:
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
lora_name
=
None
,
)
# Create events for each worker
...
...
tests/v1/kv_connector/unit/test_moriio_connector.py
0 → 100644
View file @
7e63ef82
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib.util
import
os
from
unittest.mock
import
MagicMock
,
patch
import
msgspec
import
pytest
import
torch
import
zmq
from
tests.conftest
import
_find_free_port
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
KVTransferConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common
import
(
MoRIIOAgentMetadata
,
MoRIIOConnectorMetadata
,
MoRIIOConstants
,
zmq_ctx
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector
import
(
KVConnectorRole
,
MoRIIOConnector
,
MoRIIOConnectorWorker
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.network_utils
import
(
get_ip
,
make_zmq_path
,
)
from
.utils
import
create_request
,
create_scheduler
# Detect the optional packages without importing them.
aiter_available = importlib.util.find_spec("aiter") is not None
mori_available = importlib.util.find_spec("mori") is not None

# Skip every test in this module unless the platform is ROCm and the `mori`
# package can be found. Tests that also need `aiter` carry their own skipif.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_rocm() and mori_available),
    # The condition checks `mori`, so the reason must name `mori` (not `aiter`).
    reason="MoRIIOs are only available on ROCm with mori package installed",
)
@pytest.fixture
def mock_parallel_groups():
    """Patch the parallel-group helpers looked up by the MoRIIO modules.

    Yields:
        The ``MagicMock`` that the patched ``get_world_group`` and
        ``get_tp_group`` return (rank 0, local rank 0, world size 1).
    """
    fake_group = MagicMock(rank=0, local_rank=0, world_size=1)

    # NOTE(review): the patched tensor-parallel world size is 0 although the
    # fake group reports world_size == 1 -- confirm this is intentional.
    common_patcher = patch.multiple(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common",
        get_tensor_model_parallel_rank=MagicMock(return_value=0),
        get_tensor_model_parallel_world_size=MagicMock(return_value=0),
    )
    connector_patcher = patch.multiple(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector",
        get_tensor_model_parallel_world_size=MagicMock(return_value=0),
        get_world_group=MagicMock(return_value=fake_group),
        get_tp_group=MagicMock(return_value=fake_group),
    )
    with common_patcher, connector_patcher:
        yield fake_group
def
_setup_kv_transfer_request
(
request
,
remote_host
=
"127.0.0.1"
,
fake_port
=
4789
):
"""Setup KV transfer parameters for a request."""
request
.
kv_transfer_params
.
update
(
{
"remote_notify_port"
:
fake_port
,
"remote_block_ids"
:
None
,
"remote_host"
:
remote_host
,
"remote_port"
:
fake_port
,
"remote_handshake_port"
:
fake_port
,
"remote_engine_id"
:
"test_engine"
,
}
)
return
request
class FakeMorIIOWrapper:
    """Inert stand-in for MoRIIOWrapper: every method is a no-op returning None."""

    def __init__(self, *args, **kwargs):
        """Accept and discard any constructor arguments."""

    def set_moriio_engine(self, moriio_engine):
        """No-op."""

    def set_backend_type(self, backend_type):
        """No-op."""

    def get_agent_metadata(self):
        """No-op; returns None."""

    def register_remote_engine(self, remote_packed_engine_metadata):
        """No-op."""

    def register_local_tensor(self, tensor: torch.Tensor):
        """No-op."""

    def get_unpack_memory_metadata(self, packed_memory_metadata):
        """No-op; returns None."""

    def build_session(self, local_memory_metadata, remote_memory_metadata):
        """No-op; returns None."""

    def read_remote_data(
        self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
    ):
        """No-op."""

    def write_remote_data(
        self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
    ):
        """No-op."""

    def write_remote_data_single(
        self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
    ):
        """No-op."""

    def waiting_for_transfer_complete(self):
        """No-op."""

    def async_wait_reqid(self):
        """No-op."""

    def _handle_message(self, msg: bytes):
        """No-op."""

    def _handle_structured_message(self, data: dict):
        """No-op."""

    def _handle_completion_message(self, msg: str):
        """No-op."""

    def send_notify(self, req_ids, remote_ip, remote_port):
        """No-op."""

    def pop_finished_req_ids(self):
        """No-op; returns None."""

    def pop_finished_write_req_ids(self):
        """No-op; returns None."""

    def shutdown(self):
        """No-op."""
class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker):
    """Test subclass of ``MoRIIOConnectorWorker``.

    The constructor accepts two extra keyword-only options
    (``hand_shake_latency`` and ``kv_cache_layout``) and does not forward
    them; all remaining arguments go to the base class unchanged.
    """

    # Fake remote engine id for tests.
    REMOTE_ENGINE_ID = "remote_engine"

    def __init__(
        self,
        *args,
        hand_shake_latency: float = 1.8,
        kv_cache_layout="HND",
        **kwargs,
    ):
        # Both keyword-only options are swallowed here on purpose of the
        # signature; only *args / **kwargs reach the real worker.
        super(FakeMorIIOConnectorWorker, self).__init__(*args, **kwargs)
def create_vllm_config(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 64,
    block_size: int = 16,
    max_model_len: int = 10000,
    enable_chunked_prefill: bool = True,
    enable_permute_local_kv: bool = False,
    role="kv_consumer",
) -> VllmConfig:
    """Build a CPU-device ``VllmConfig`` wired to the MoRIIO connector.

    Args:
        model: Model name handed to ``ModelConfig``.
        max_num_seqs: Scheduler limit on sequences.
        max_num_batched_tokens: Scheduler limit on batched tokens.
        block_size: KV cache block size.
        max_model_len: Scheduler maximum model length.
        enable_chunked_prefill: Forwarded to ``SchedulerConfig``.
        enable_permute_local_kv: Forwarded to ``KVTransferConfig``.
        role: ``kv_role`` of the connector.

    Returns:
        The assembled ``VllmConfig``.
    """
    # Sub-configs are built inline, in the same order as the keyword list.
    return VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            max_model_len=max_model_len,
            enable_chunked_prefill=enable_chunked_prefill,
            is_encoder_decoder=False,
        ),
        model_config=ModelConfig(
            model=model,
            trust_remote_code=True,
            dtype="bfloat16",
            seed=42,
        ),
        # Prefix caching is always enabled for these tests.
        cache_config=CacheConfig(
            block_size=block_size,
            gpu_memory_utilization=0.9,
            swap_space=0,
            cache_dtype="auto",
            enable_prefix_caching=True,
        ),
        kv_transfer_config=KVTransferConfig(
            kv_connector="MoRIIOConnector",
            kv_role=role,
            enable_permute_local_kv=enable_permute_local_kv,
        ),
        device_config=DeviceConfig("cpu"),
    )
@pytest.fixture
def moriio_read_mode():
    """Force the connector into read mode via env for the duration of a test.

    Sets ``VLLM_MORIIO_CONNECTOR_READ_MODE`` to ``"True"`` and, on teardown,
    puts the variable back exactly as it was: restored to its previous value
    if it had one, removed otherwise.
    """
    env_key = "VLLM_MORIIO_CONNECTOR_READ_MODE"
    # Remember any value set by the surrounding environment so the fixture
    # does not silently erase it.
    previous = os.environ.get(env_key)
    os.environ[env_key] = "True"
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
def test_write_mode_saves_local_block_ids():
    """Write mode records local block ids in MoRIIOConnectorMetadata.reqs_to_save."""
    # Producer-side scheduler.
    config = create_vllm_config(role="kv_producer")
    scheduler = create_scheduler(config)

    # Prompt length: two full blocks plus half a block.
    block_size = config.cache_config.block_size
    num_full_blocks = 2
    num_tokens = int(block_size * (num_full_blocks + 0.5))

    request = create_request(
        request_id=1,
        block_size=block_size,
        num_tokens=num_tokens,
        do_remote_decode=True,
        do_remote_prefill=False,
    )
    request_id = request.request_id
    scheduler.add_request(request)

    # Attach the fake remote-endpoint parameters.
    request = _setup_kv_transfer_request(request)

    # Scheduling the remote-decode request must yield connector metadata.
    metadata = scheduler.schedule().kv_connector_metadata
    assert metadata is not None, "kv_connector_metadata is None"
    assert isinstance(metadata, MoRIIOConnectorMetadata)

    # Exactly one request to save, nothing to receive or send.
    assert len(metadata.reqs_to_save) == 1, "Unexpected number of reqs_to_save"
    assert len(metadata.reqs_to_recv) == 0, "Unexpected number of reqs_to_recv"
    assert len(metadata.reqs_to_send) == 0, "Unexpected number of reqs_to_send"
    assert request_id in metadata.reqs_to_save, "Request ID not in reqs_to_save"

    # The recorded block ids must match the blocks the KV cache manager holds.
    saved_meta = metadata.reqs_to_save[request_id]
    manager = scheduler.kv_cache_manager.coordinator.single_type_managers[0]
    allocated_blocks = manager.req_to_blocks[request_id]
    for block_id, block in zip(saved_meta.local_block_ids, allocated_blocks):
        assert block_id == block.block_id, f"{block_id} != {block.block_id}"
def test_write_mode_with_chunked_prefill_saves_local_block_ids():
    """Write mode with chunked prefill still records correct local block ids."""
    # Setup Scheduler and Request
    MAX_NUM_BATCHED_TOKENS = 64
    # Two and a half batches of prompt tokens; the test expects this to need
    # three schedule() calls (one entry in expected_counts per call).
    NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2
    vllm_config = create_vllm_config(
        max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_producer"
    )
    BLOCK_SIZE = vllm_config.cache_config.block_size
    scheduler = create_scheduler(vllm_config)

    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_decode=True,
        do_remote_prefill=False,
    )
    request_id = request.request_id
    scheduler.add_request(request)

    # Fake Config
    request = _setup_kv_transfer_request(request)

    # Expected (reqs_to_save, reqs_to_recv, reqs_to_send) sizes per
    # schedule() call: the request is only marked for saving on the last one.
    expected_counts = [(0, 0, 0), (0, 0, 0), (1, 0, 0)]
    kv_connector_metadata = None
    for expected_save, expected_recv, expected_send in expected_counts:
        scheduler_output = scheduler.schedule()
        kv_connector_metadata = scheduler_output.kv_connector_metadata

        # Check for None before dereferencing so a missing metadata object
        # fails with a clear message instead of an AttributeError.
        assert kv_connector_metadata is not None, "kv_connector_metadata is None"
        assert len(kv_connector_metadata.reqs_to_save) == expected_save
        assert len(kv_connector_metadata.reqs_to_recv) == expected_recv
        assert len(kv_connector_metadata.reqs_to_send) == expected_send

    assert kv_connector_metadata is not None, "kv_connector_metadata is None"
    assert request_id in kv_connector_metadata.reqs_to_save, (
        "Request ID not in reqs_to_save"
    )
    req_meta = kv_connector_metadata.reqs_to_save[request_id]

    # The recorded block ids must match the blocks the KV cache manager holds.
    for block_id, block in zip(
        req_meta.local_block_ids,
        scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
            request_id
        ],
    ):
        assert block_id == block.block_id, f"{block_id} != {block.block_id}"
def test_read_mode_loads_remote_block_ids(moriio_read_mode):
    """Read mode loads remote block ids into local cache mapping."""
    # Consumer-side scheduler.
    config = create_vllm_config(role="kv_consumer")
    scheduler = create_scheduler(config)
    manager = scheduler.kv_cache_manager.coordinator.single_type_managers[0]

    # Prompt length: two full blocks plus half a block.
    block_size = config.cache_config.block_size
    num_full_blocks = 2
    num_tokens = int(block_size * (num_full_blocks + 0.5))

    request = create_request(
        request_id=1,
        block_size=block_size,
        num_tokens=num_tokens,
        do_remote_decode=False,
        do_remote_prefill=True,
    )
    request_id = request.request_id
    scheduler.add_request(request)

    # Looked up before scheduling, exactly as the remote ids are wired below.
    block_list = manager.req_to_blocks[request_id]

    request = _setup_kv_transfer_request(request)
    # Set remote block ids to be fetched.
    request.kv_transfer_params["remote_block_ids"] = block_list

    # Remote Prefill, triggers MoRIIOConnectorMetadata.
    metadata = scheduler.schedule().kv_connector_metadata
    assert metadata is not None, "kv_connector_metadata is None"
    assert isinstance(metadata, MoRIIOConnectorMetadata), (
        "kv_connector_metadata is not MoRIIOConnectorMetadata"
    )

    # Exactly one request to receive, nothing to save or send.
    assert len(metadata.reqs_to_save) == 0, "Unexpected number of reqs_to_save"
    assert len(metadata.reqs_to_recv) == 1, "Unexpected number of reqs_to_recv"
    assert len(metadata.reqs_to_send) == 0, "Unexpected number of reqs_to_send"
    assert request_id in metadata.reqs_to_recv, "Request ID not in reqs_to_recv"

    # The recorded block ids must match the blocks the KV cache manager holds.
    recv_meta = metadata.reqs_to_recv[request_id]
    allocated_blocks = manager.req_to_blocks[request_id]
    for block_id, block in zip(recv_meta.local_block_ids, allocated_blocks):
        assert block_id == block.block_id, f"{block_id} != {block.block_id}"
@pytest.mark.skipif(
    not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend"
)
def test_register_kv_caches(mock_parallel_groups):
    """Test that MoRIIOConnector.register_kv_caches correctly registers kv caches."""
    ROLE = "kv_consumer"
    IP = get_ip()
    vllm_config = create_vllm_config(role=ROLE)
    DEFAULT_PORT = 6301
    TP_RANK = 0
    DP_RANK = 0

    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend

    backend_cls = AiterFlashAttentionBackend

    # Test tensors shaped the way the attention backend lays out its KV cache.
    kv_cache_shape = backend_cls.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    # layer0 and layer2 deliberately point at the same tensor.
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }
    expected_layers = (
        ("layer0", shared_tensor),
        ("layer1", unique_tensor),
        ("layer2", shared_tensor),
    )

    event_patcher = patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event"
    )
    thread_patcher = patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread"
    )
    with event_patcher, thread_patcher:
        # Create connector
        vllm_config.kv_transfer_config.kv_connector_extra_config.update(
            {
                "proxy_ip": "127.0.0.1",
                "proxy_ping_port": 12345,
                "http_port": 12346,
            }
        )
        connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeMorIIOConnectorWorker(
            vllm_config, connector.engine_id, hand_shake_latency=0
        )

        from mori.io import MemoryDesc

        # Execute register_kv_caches
        connector.register_kv_caches(kv_caches)

        # Each layer's first stored MemoryDesc must point at that layer's tensor.
        for layer_name, tensor in expected_layers:
            packed_desc = (
                connector.connector_worker.layer_name_to_local_kv_cache_metadata[
                    layer_name
                ][0]
            )
            assert tensor.data_ptr() == MemoryDesc.unpack(packed_desc).data

        # Verify engine keys
        expected_engine_key = (
            f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
        )
        layer0_desc = MemoryDesc.unpack(
            connector.connector_worker.layer_name_to_local_kv_cache_metadata[
                "layer0"
            ][0]
        )
        assert layer0_desc.engine_key == expected_engine_key
@pytest.mark.skipif(
    not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend"
)
def test_moriio_handshake_returns_metadata(mock_parallel_groups):
    """MoRIIO handshake socket returns valid agent metadata over ZMQ."""
    ROLE = "kv_consumer"
    vllm_config = create_vllm_config(role=ROLE)

    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend

    backend_cls = AiterFlashAttentionBackend

    # Test tensors shaped the way the attention backend lays out its KV cache.
    kv_cache_shape = backend_cls.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }

    wrapper_patcher = patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
        FakeMorIIOWrapper,
    )
    with wrapper_patcher:
        handshake_port = _find_free_port()

        # Create connector
        extra_config = {
            "proxy_ip": "127.0.0.1",
            "proxy_ping_port": 12345,
            "http_port": 12346,
            "handshake_port": handshake_port,
        }
        vllm_config.kv_transfer_config.kv_connector_extra_config.update(extra_config)
        connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)

        # Execute register_kv_caches
        connector.register_kv_caches(kv_caches)

        # Connect to handshake socket and request metadata
        path = make_zmq_path("tcp", "127.0.0.1", handshake_port)
        with zmq_ctx(zmq.DEALER, path) as sock:
            sock.send(MoRIIOConstants.GET_META_MSG)
            received_frame = sock.recv_multipart()

            # A valid reply is exactly two frames, the first one empty.
            if not (len(received_frame) == 2 and received_frame[0] == b""):
                raise ValueError(f"Unexpected frame! {received_frame = }")

            metadata = msgspec.msgpack.Decoder(MoRIIOAgentMetadata).decode(
                received_frame[1]
            )
            assert isinstance(metadata, MoRIIOAgentMetadata), (
                "Decoded metadata is not MoRIIOAgentMetadata"
            )
tests/v1/kv_connector/unit/test_multi_connector.py
View file @
7e63ef82
...
...
@@ -51,6 +51,33 @@ class MockConnector(KVConnectorBase_V1):
)
->
KVConnectorStats
|
None
:
return
MockConnectorStats
(
data
=
data
)
if
data
is
not
None
else
None
def start_load_kv(self, forward_context, **kwargs):
    """Mock implementation: does nothing and returns None."""
    return None
def wait_for_layer_load(self, layer_name):
    """Mock implementation: does nothing and returns None."""
    return None
def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
    """Mock implementation: does nothing and returns None."""
    return None
def wait_for_save(self):
    """Mock implementation: does nothing and returns None."""
    return None
def build_connector_meta(self, scheduler_output):
    """Mock implementation: produces no connector metadata (always None)."""
    metadata = None
    return metadata
def get_num_new_matched_tokens(self, request, num_computed_tokens):
    """Mock implementation: always returns the pair ``(0, False)``."""
    return 0, False
def update_state_after_alloc(self, request, blocks, num_tokens) -> None:
    """Mock implementation: does nothing and returns None."""
    return None
class MockCrossLayerConnector(MockConnector):
    """MockConnector variant whose ``prefer_cross_layer_blocks`` is True."""

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        """Always report a preference for cross-layer blocks."""
        prefers_cross_layer = True
        return prefers_cross_layer
# Register the mock connector
KVConnectorFactory
.
register_connector
(
"MockConnector"
,
__name__
,
MockConnector
.
__name__
)
...
...
@@ -603,3 +630,21 @@ class TestMultiConnectorStats:
# One non-empty
stats
.
data
[
"NixlConnector"
].
data
[
"transfer_duration"
].
append
(
1.0
)
assert
not
stats
.
is_empty
()
class TestMultiConnectorPreferCrossLayerBlocks:
    """Checks MultiConnector.prefer_cross_layer_blocks against its sub-connectors."""

    @staticmethod
    def _bare_multi_connector(*connector_types):
        """Build a MultiConnector via __new__ holding bare sub-connector instances."""
        multi = MultiConnector.__new__(MultiConnector)
        multi._connectors = [
            connector_type.__new__(connector_type)
            for connector_type in connector_types
        ]
        return multi

    def test_all_connectors_prefer_cross_layer_blocks(self):
        # Every sub-connector prefers cross-layer blocks.
        mc = self._bare_multi_connector(
            MockCrossLayerConnector, MockCrossLayerConnector
        )
        assert mc.prefer_cross_layer_blocks is True

    def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
        # The plain MockConnector keeps the default (False) preference.
        mc = self._bare_multi_connector(MockCrossLayerConnector, MockConnector)
        assert mc.prefer_cross_layer_blocks is False
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
7e63ef82
...
...
@@ -41,10 +41,13 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
has_kv_transfer_group
,
)
from
vllm.forward_context
import
ForwardContext
from
vllm.outputs
import
RequestOutput
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
Platform
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
vllm.v1.request
import
RequestStatus
...
...
@@ -182,18 +185,21 @@ class FakeNixlWrapper:
def
_make_fake_nixl_pkg
():
"""Context manager that creates a temporary package making
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.
Also creates rixl package for ROCm compatibility.
Automatically cleans up the temporary directory when done.
"""
with
tempfile
.
TemporaryDirectory
()
as
td
:
pkg_root
=
os
.
path
.
join
(
td
,
"nixl"
,
"_api"
)
os
.
makedirs
(
pkg_root
,
exist_ok
=
True
)
# Create both nixl and rixl packages for cross-platform compatibility
for
pkg_name
in
[
"nixl"
,
"rixl"
]:
pkg_root
=
os
.
path
.
join
(
td
,
pkg_name
,
"_api"
)
os
.
makedirs
(
pkg_root
,
exist_ok
=
True
)
# Get the source code of FakeNixlWrapper class and dedent it
fake_nixl_source
=
inspect
.
getsource
(
FakeNixlWrapper
)
fake_nixl_source
=
textwrap
.
dedent
(
fake_nixl_source
)
# Get the source code of FakeNixlWrapper class and dedent it
fake_nixl_source
=
inspect
.
getsource
(
FakeNixlWrapper
)
fake_nixl_source
=
textwrap
.
dedent
(
fake_nixl_source
)
stub
=
f
"""
\
stub
=
f
"""
\
# Copy of FakeNixlWrapper implementation for Ray workers
import uuid
from collections import defaultdict
...
...
@@ -203,16 +209,17 @@ from collections import defaultdict
# Export as nixl_agent
nixl_agent = FakeNixlWrapper
"""
with
open
(
os
.
path
.
join
(
pkg_root
,
"__init__.py"
),
"w"
)
as
f
:
f
.
write
(
stub
)
# Mock nixlXferTelemetry class
pkg_root2
=
os
.
path
.
join
(
td
,
"nixl"
,
"_bindings"
)
os
.
makedirs
(
pkg_root2
,
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
pkg_root2
,
"__init__.py"
),
"w"
)
as
f
:
f
.
write
(
"class nixlXferTelemetry: pass"
)
# touch parent package
open
(
os
.
path
.
join
(
td
,
"nixl"
,
"__init__.py"
),
"w"
).
close
()
with
open
(
os
.
path
.
join
(
pkg_root
,
"__init__.py"
),
"w"
)
as
f
:
f
.
write
(
stub
)
# Mock nixlXferTelemetry class
pkg_root2
=
os
.
path
.
join
(
td
,
pkg_name
,
"_bindings"
)
os
.
makedirs
(
pkg_root2
,
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
pkg_root2
,
"__init__.py"
),
"w"
)
as
f
:
f
.
write
(
"class nixlXferTelemetry: pass"
)
# touch parent package
open
(
os
.
path
.
join
(
td
,
pkg_name
,
"__init__.py"
),
"w"
).
close
()
yield
td
...
...
@@ -296,6 +303,7 @@ def test_prompt_less_than_block_size():
)
def
test_kv_transfer_handshake
(
dist_init
):
"""Unit test for basic NixlConnector interface functionality."""
from
vllm.config
import
set_current_vllm_config
# Test setup: we create a scheduler that contains a NixlConnector
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
...
...
@@ -305,81 +313,82 @@ def test_kv_transfer_handshake(dist_init):
vllm_config
.
kv_transfer_config
.
kv_buffer_device
=
"cpu"
scheduler
=
create_scheduler
(
vllm_config
)
# Create two NixlConnector of role WORKER, one is the worker of
# the scheduler (prefill), the other is a worker of decode instance.
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
kv_cache_shape
=
FlashAttentionBackend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
kv_caches
=
{
"layer0"
:
shared_tensor
,
"layer1"
:
unique_tensor
,
"layer2"
:
shared_tensor
,
}
prefill_connector
.
register_kv_caches
(
kv_caches
)
# Simulate EngineCore initialization that would gather connector
# metadata from all workers
metadata
=
prefill_connector
.
get_handshake_metadata
()
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
expected_agent_metadata
=
decoder
.
decode
(
metadata
.
agent_metadata_bytes
)
# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector
=
scheduler
.
get_kv_connector
()
scheduler_connector
.
set_xfer_handshake_metadata
({
0
:
metadata
})
with
set_current_vllm_config
(
vllm_config
):
# Create two NixlConnector of role WORKER, one is the worker of
# the scheduler (prefill), the other is a worker of decode instance.
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
BLOCK_SIZE
=
vllm_config
.
cache_config
.
block_size
NUM_EXTERNAL_FULL_BLOCKS
=
2
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
(
NUM_EXTERNAL_FULL_BLOCKS
+
0.5
))
request
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
,
do_remote_decode
=
True
,
)
request
.
status
=
RequestStatus
.
FINISHED_LENGTH_CAPPED
delay
,
kv_connector_metadata
=
scheduler
.
get_kv_connector
().
request_finished
(
request
,
[
0
,
1
,
2
]
)
assert
delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
# to validate the metadata received is the same as the one in prefill_connector.
with
patch
.
object
(
decode_connector
.
connector_worker
,
"add_remote_agent"
)
as
mock_add_remote_agent
:
mock_add_remote_agent
.
return_type
=
"remote_agent"
decode_connector
.
connector_worker
.
_nixl_handshake
(
kv_connector_metadata
[
"remote_host"
],
kv_connector_metadata
[
"remote_port"
],
kv_connector_metadata
[
"tp_size"
],
kv_connector_metadata
[
"remote_engine_id"
],
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
kv_cache_shape
=
FlashAttentionBackend
.
get_kv_cache_shape
(
num_blocks
=
2
,
block_size
=
16
,
num_kv_heads
=
4
,
head_size
=
64
)
shared_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
unique_tensor
=
torch
.
zeros
(
*
kv_cache_shape
,
dtype
=
torch
.
float16
)
kv_caches
=
{
"layer0"
:
shared_tensor
,
"layer1"
:
unique_tensor
,
"layer2"
:
shared_tensor
,
}
prefill_connector
.
register_kv_caches
(
kv_caches
)
# Simulate EngineCore initialization that would gather connector
# metadata from all workers
metadata
=
prefill_connector
.
get_handshake_metadata
()
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
expected_agent_metadata
=
decoder
.
decode
(
metadata
.
agent_metadata_bytes
)
# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector
=
scheduler
.
get_kv_connector
()
scheduler_connector
.
set_xfer_handshake_metadata
({
0
:
metadata
})
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
BLOCK_SIZE
=
vllm_config
.
cache_config
.
block_size
NUM_EXTERNAL_FULL_BLOCKS
=
2
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
(
NUM_EXTERNAL_FULL_BLOCKS
+
0.5
))
request
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
,
do_remote_decode
=
True
,
)
request
.
status
=
RequestStatus
.
FINISHED_LENGTH_CAPPED
delay
,
kv_connector_metadata
=
scheduler
.
get_kv_connector
().
request_finished
(
request
,
[
0
,
1
,
2
]
)
assert
delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
# to validate the metadata received is the same as the one in prefill_connector.
with
patch
.
object
(
decode_connector
.
connector_worker
,
"add_remote_agent"
)
as
mock_add_remote_agent
:
mock_add_remote_agent
.
return_type
=
"remote_agent"
decode_connector
.
connector_worker
.
_nixl_handshake
(
kv_connector_metadata
[
"remote_host"
],
kv_connector_metadata
[
"remote_port"
],
kv_connector_metadata
[
"tp_size"
],
kv_connector_metadata
[
"remote_engine_id"
],
)
received_metadata
=
mock_add_remote_agent
.
call_args
.
args
assert
received_metadata
[
0
]
==
expected_agent_metadata
assert
received_metadata
[
1
]
==
0
# remote_tp_rank
assert
received_metadata
[
2
]
==
1
# remote_tp_size
received_metadata
=
mock_add_remote_agent
.
call_args
.
args
assert
received_metadata
[
0
]
==
expected_agent_metadata
assert
received_metadata
[
1
]
==
0
# remote_tp_rank
assert
received_metadata
[
2
]
==
1
# remote_tp_size
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector
.
shutdown
()
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector
.
shutdown
()
class
FakeNixlConnectorWorker
(
NixlConnectorWorker
):
...
...
@@ -391,6 +400,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_hand_shake_latency
=
hand_shake_latency
self
.
kv_cache_layout
=
kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self
.
src_xfer_handles_by_block_size
=
{
self
.
block_size
:
1
}
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
,
remote_tp_size
:
int
,
expected_engine_id
:
str
...
...
@@ -407,22 +418,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
assert
expected_engine_id
==
self
.
REMOTE_ENGINE_ID
remote_agent_name
=
self
.
add_remote_agent
(
NixlAgentMetadata
(
engine_id
=
self
.
REMOTE_ENGINE_ID
,
agent_metadata
=
FakeNixlWrapper
.
AGENT_METADATA
,
kv_caches_base_addr
=
[
0
],
device_id
=
0
,
num_blocks
=
1
,
block_lens
=
self
.
block_len_per_layer
,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout
=
"HND"
,
block_size
=
self
.
block_size
,
),
remote_tp_size
=
remote_tp_size
,
)
return
{
0
:
remote_agent_name
}
# Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation.
remote_block_lens
=
list
(
self
.
block_len_per_layer
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio
(
remote_tp_size
)
if
remote_tp_size
>
self
.
world_size
:
# P TP > D TP case, block_len of remote is smaller
remote_block_lens
=
[
block_len
//
(
-
tp_ratio
)
for
block_len
in
remote_block_lens
]
elif
remote_tp_size
<
self
.
world_size
:
remote_block_lens
=
[
block_len
*
tp_ratio
for
block_len
in
remote_block_lens
]
# When remote tp_size > local tp_size, handshake with multiple
# remote ranks.
num_hanshakes
=
1
if
tp_ratio
>
0
else
-
tp_ratio
remote_agents
:
dict
[
int
,
str
]
=
{}
for
remote_tp_rank
in
range
(
num_hanshakes
):
remote_agent_name
=
self
.
add_remote_agent
(
NixlAgentMetadata
(
engine_id
=
self
.
REMOTE_ENGINE_ID
,
agent_metadata
=
FakeNixlWrapper
.
AGENT_METADATA
,
kv_caches_base_addr
=
[
0
],
device_id
=
remote_tp_rank
,
num_blocks
=
1
,
block_lens
=
remote_block_lens
,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout
=
"HND"
,
block_size
=
self
.
block_size
,
),
remote_tp_rank
=
remote_tp_rank
,
remote_tp_size
=
remote_tp_size
,
)
remote_agents
[
remote_tp_rank
]
=
remote_agent_name
return
remote_agents
class
TestNixlHandshake
:
...
...
@@ -432,6 +464,7 @@ class TestNixlHandshake:
)
def
test_multi_xfer_one_engine
(
self
,
default_vllm_config
,
# dist_init is a fixture that initializes the distributed environment.
dist_init
,
):
...
...
@@ -453,7 +486,13 @@ class TestNixlHandshake:
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
assert
isinstance
(
connector
.
connector_worker
.
nixl_wrapper
,
FakeNixlWrapper
)
connector
.
connector_worker
.
nixl_wrapper
.
set_cycles_before_xfer_done
(
3
)
worker
=
connector
.
connector_worker
worker
.
nixl_wrapper
.
set_cycles_before_xfer_done
(
3
)
# simulate handshake
worker
.
dst_xfer_side_handles
=
{
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
:
{
0
:
1
}
}
worker
.
kv_cache_layout
=
"HND"
num_xfers
=
4
while
True
:
# For the same request_id, initiate multiple xfers across different
...
...
@@ -515,6 +554,7 @@ class TestNixlHandshake:
)
def
test_async_load_kv
(
self
,
default_vllm_config
,
# Fixture that initializes the distributed environment.
dist_init
,
# Simulate consumer-producer TP sizes.
...
...
@@ -567,12 +607,178 @@ class TestNixlHandshake:
return
raise
TimeoutError
(
"Took too long to complete async handshake."
)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
    self, local_tp_size: int, default_vllm_config, dist_init
):
    """
    Handshake with remote engines whose TP size exceeds the local TP size.

    Two handshakes are performed against the fake worker (remote TP 2, then
    a second engine id with remote TP 6). After each one the test checks the
    per-rank remote agents, the recorded remote TP size, the negative TP
    ratio reported by ``kv_topo`` and the source/destination transfer
    handles.
    """
    # NOTE(review): this assignment replaces the parametrized value, so both
    # parametrizations currently run with a local TP size of 1 — confirm
    # whether that is intended.
    local_tp_size = 1
    vllm_config = create_vllm_config()
    vllm_config.parallel_config.tensor_parallel_size = local_tp_size

    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        vllm_config, connector.engine_id, hand_shake_latency=0
    )
    worker = connector.connector_worker

    # Minimal local registration params used by add_remote_agent
    slot_size = 4096
    worker.slot_size_per_layer = [slot_size]
    worker.block_len_per_layer = [slot_size * worker.block_size]
    worker.num_blocks = 1
    worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
    worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]

    def handshake_and_verify(remote_tp_size: int) -> None:
        # Handshake against the engine id currently set on the worker.
        remote_agents = worker._nixl_handshake(
            host="localhost",
            port=1234,
            remote_tp_size=remote_tp_size,
            expected_engine_id=worker.REMOTE_ENGINE_ID,
        )

        tp_ratio = remote_tp_size // local_tp_size
        expected_ranks = set(range(tp_ratio))
        remote_engine_id = worker.REMOTE_ENGINE_ID

        # One remote agent per remote rank.
        assert set(remote_agents.keys()) == expected_ranks
        assert worker._tp_size[remote_engine_id] == remote_tp_size
        assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
        # ensure src_xfer_handles_by_tp_ratio is populated with tp_ratio chunks
        assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
        assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
        assert remote_engine_id in worker.dst_xfer_side_handles
        assert (
            set(worker.dst_xfer_side_handles[remote_engine_id].keys())
            == expected_ranks
        )

    handshake_and_verify(2)

    # NOTE flexibility: a second remote with higher number of ranks is
    # discovered. This is not a scenario we actively support right now, but
    # the connector allows it.
    worker.REMOTE_ENGINE_ID = "remote_engine_2"
    handshake_and_verify(6)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
    self, local_tp_size: int, default_vllm_config, dist_init
):
    """
    Remote TP > local TP scenario for an MLA model.

    Two producer (P) workers are set up with ``use_mla`` enabled and a world
    size of 2. Both track the same pending request and both receive a read
    notification. The test then checks that each connector reports the
    request as finished sending, that the aggregated output does too, and
    that the workers dropped their bookkeeping for it.

    NOTE(review): the parametrized ``local_tp_size`` is not referenced in the
    body — confirm whether it should drive ``d_tp_size``.
    """
    vllm_config = create_vllm_config()
    d_tp_size = 1
    p_tp_size = 2

    # Build two separate connectors/workers to emulate P TP=2 ranks.
    connectors = [
        NixlConnector(vllm_config, KVConnectorRole.WORKER) for _ in range(2)
    ]
    for conn in connectors:
        conn.connector_worker = FakeNixlConnectorWorker(
            vllm_config, conn.engine_id, hand_shake_latency=0
        )
    conn_p0, conn_p1 = connectors

    # Force P world size to 2 for both workers and emulate distinct tp_ranks.
    # Also enable MLA path so that expected_finished_count is updated.
    for rank, conn in enumerate(connectors):
        p_worker = conn.connector_worker
        p_worker.world_size = p_tp_size
        p_worker.kv_topo.remote_tp_size = {p_worker.engine_id: p_tp_size}
        p_worker.tp_rank = rank
        p_worker.use_mla = True

    req_id = "req-ep-dp2-p0"
    now = time.perf_counter()
    # Register a request on P that is waiting for consumers to read
    # (both workers track it).
    for conn in connectors:
        conn.connector_worker._reqs_to_send[req_id] = now + 10.0
        conn.connector_worker._reqs_to_process.add(req_id)

    # Simulate a read notification coming from D with (tp=1, dp=2).
    notif = f"{req_id}:{d_tp_size}".encode()
    # D0-0->P0 notif
    for conn in connectors:
        conn.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
            "agent": [notif]
        }  # type: ignore[method-assign]

    # Trigger notification processing via get_finished().
    done_sending = []
    for conn in connectors:
        finished_sending, _ = conn.get_finished(finished_req_ids=set())
        done_sending.append(finished_sending)
    assert all(req_id in finished for finished in done_sending)

    # E2E aggregation: ensure the aggregated output marks the request
    # as finished using the connector's expected_finished_count.
    from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

    aggregator = KVOutputAggregator.from_connector(conn_p0, world_size=2)

    def make_runner_output(finished_sending):
        # Per-rank output whose only varying part is the finished_sending set.
        return ModelRunnerOutput(
            req_ids=[req_id],
            req_id_to_index={req_id: 0},
            sampled_token_ids=[[0]],
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[None],
            kv_connector_output=KVConnectorOutput(
                finished_sending=finished_sending,
                finished_recving=None,
            ),
        )

    rank_outputs = [make_runner_output(finished) for finished in done_sending]
    aggregated = aggregator.aggregate(rank_outputs, output_rank=0)
    assert aggregated.kv_connector_output is not None
    assert aggregated.kv_connector_output.finished_sending == {req_id}

    # Producers cleaned up state for the finished request.
    for conn in (conn_p0, conn_p1):
        assert req_id not in conn.connector_worker._reqs_to_send
        assert req_id not in conn.connector_worker._reqs_to_process
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
,
)
def
test_concurrent_load_kv
(
self
,
default_vllm_config
,
# dist_init is a fixture that initializes the distributed environment.
dist_init
,
):
...
...
@@ -585,6 +791,9 @@ class TestNixlHandshake:
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
)
# Register (mocked) local xfer handler
# worker = connector.connector_worker
# worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
metadata
=
NixlConnectorMetadata
()
total_reqs
=
5
for
i
in
range
(
total_reqs
):
...
...
@@ -630,7 +839,9 @@ class TestNixlHandshake:
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
,
)
def
test_handshake_fails_on_kv_cache_layout_mismatch
(
self
,
dist_init
):
def
test_handshake_fails_on_kv_cache_layout_mismatch
(
self
,
default_vllm_config
,
dist_init
):
"""
Verify that adding a remote agent fails if kv_cache_layout differs.
This test is only relevant for heterogeneous TP.
...
...
@@ -672,7 +883,6 @@ class TestNixlHandshake:
with
pytest
.
raises
(
RuntimeError
):
# mismatched layout is expected to fail
worker
.
add_remote_agent
(
meta
,
remote_tp_size
=
2
)
with
pytest
.
raises
(
AssertionError
):
worker
.
add_remote_agent
(
meta
,
remote_tp_size
=
1
)
@
patch
(
...
...
@@ -680,7 +890,7 @@ class TestNixlHandshake:
FakeNixlWrapper
,
)
def
test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental
(
self
,
dist_init
self
,
default_vllm_config
,
dist_init
):
"""
Verify that adding a remote agent fails if kv_cache_layout differs.
...
...
@@ -735,7 +945,7 @@ class TestNixlHandshake:
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
,
)
def
test_kv_connector_stats
(
dist_init
):
def
test_kv_connector_stats
(
default_vllm_config
,
dist_init
):
"""Test that KV transfer stats are properly recorded and retrieved."""
vllm_config
=
create_vllm_config
()
...
...
@@ -1069,6 +1279,22 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
run_test_and_cleanup
()
class RequestIdMapper:
    """Record the external -> internal request ID of every added request.

    On construction the instance replaces ``output_processor.add_request``
    with its own wrapper. The wrapper stores
    ``request.external_req_id -> request.request_id`` and then delegates to
    the original method. Calling the instance with an external ID returns the
    recorded internal ID.
    """

    def __init__(self, output_processor: OutputProcessor):
        # Keep the real method so the wrapper below can delegate to it.
        self.original_add_request = output_processor.add_request
        self.req_id_mapping: dict[str, str] = {}
        # From now on every add_request call goes through _add_request.
        output_processor.add_request = self._add_request

    def _add_request(self, request: EngineCoreRequest, *args, **kwargs):
        """Remember the request's ID pair, then forward to the original."""
        external_id = request.external_req_id
        self.req_id_mapping[external_id] = request.request_id
        return self.original_add_request(request, *args, **kwargs)

    def __call__(self, external_req_id: str) -> str:
        """Return the internal ID recorded for ``external_req_id``.

        Raises:
            KeyError: if no request with that external ID was added.
        """
        return self.req_id_mapping[external_req_id]
def
_run_abort_timeout_test
(
llm
:
LLM
,
timeout
:
int
):
"""Helper function to run the abort timeout test logic."""
remote_prefill_opts
=
{
...
...
@@ -1090,24 +1316,34 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
0
].
req_to_blocks
id_mapper
=
RequestIdMapper
(
llm
.
llm_engine
.
output_processor
)
def
req_id
(
outputs
:
list
[
RequestOutput
])
->
str
:
assert
len
(
outputs
)
==
1
return
id_mapper
(
outputs
[
0
].
request_id
)
padding
=
"Just making this request a little longer so that we're sure "
"we're not hitting the small-request lower bound beneath which we don't "
"actually trigger the whole kv transfer, but rather just recompute the "
"blocks on D."
_
=
llm
.
generate
([
f
"What is the capital of Japan?
{
padding
}
"
],
sampling_params
)
req0_id
=
req_id
(
llm
.
generate
([
f
"What is the capital of Japan?
{
padding
}
"
],
sampling_params
)
)
# Request finished but not freed
assert
"0"
in
scheduler
.
finished_req_ids
and
"0"
in
req_to_blocks
assert
req0_id
in
scheduler
.
finished_req_ids
and
req0_id
in
req_to_blocks
# Some other request, 0 still not freed
_
=
llm
.
generate
([
f
"What is the capital of Italy?
{
padding
}
"
],
sampling_params
)
assert
"0"
in
req_to_blocks
assert
"1"
in
scheduler
.
finished_req_ids
and
"1"
in
req_to_blocks
req1_id
=
req_id
(
llm
.
generate
([
f
"What is the capital of Italy?
{
padding
}
"
],
sampling_params
)
)
assert
req0_id
in
req_to_blocks
assert
req1_id
in
scheduler
.
finished_req_ids
and
req1_id
in
req_to_blocks
# Wait for timeout and trigger another scheduler loop
time
.
sleep
(
timeout
)
_
=
llm
.
generate
([
f
"What is the capital of France?
{
padding
}
"
],
sampling_params
)
# Request-0 times out and is cleared!
assert
"0"
not
in
req_to_blocks
assert
req0_id
not
in
req_to_blocks
# Need to shutdown the background thread to release NIXL side channel port
llm
.
llm_engine
.
engine_core
.
shutdown
()
...
...
@@ -1132,7 +1368,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
"TRITON_ATTN"
,
],
)
def
test_register_kv_caches
(
dist_init
,
attn_backend
,
monkeypatch
):
def
test_register_kv_caches
(
default_vllm_config
,
dist_init
,
attn_backend
):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
...
...
@@ -1144,9 +1380,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
block layout info
"""
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
(
attention_backend
=
attn_backend
)
# Import the appropriate backend based on the parameter
if
attn_backend
==
"FLASH_ATTN"
:
...
...
@@ -1205,7 +1439,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
patch
(
f
"
{
nixl_module
}
.NixlWrapper"
)
as
mock_nixl_wrapper
,
patch
(
f
"
{
nixl_module
}
.threading.Event"
),
patch
(
f
"
{
nixl_module
}
.threading.Thread"
)
as
mock_thread
,
patch
(
f
"
{
nixl_module
}
.get_attn_backend"
)
as
mock_get_attn_backend
,
patch
(
f
"
{
nixl_module
}
.get_
current_
attn_backend"
)
as
mock_get_attn_backend
,
):
# Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous
...
...
@@ -1295,7 +1529,9 @@ class FakePlatform(Platform):
(
"oot"
,
"VRAM"
),
],
)
def
test_kv_buffer_to_nixl_memory_types
(
dist_init
,
kv_buffer_device
,
nixl_memory_type
):
def
test_kv_buffer_to_nixl_memory_types
(
default_vllm_config
,
dist_init
,
kv_buffer_device
,
nixl_memory_type
):
"""
Test that register_kv_caches() passes the correct memory types from the
config to the nixl_wrapper.
...
...
@@ -1340,7 +1576,7 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
,
)
def
test_shutdown_cleans_up_resources
(
dist_init
):
def
test_shutdown_cleans_up_resources
(
default_vllm_config
,
dist_init
):
"""Test that shutdown() properly cleans up all resources."""
vllm_config
=
create_vllm_config
()
...
...
@@ -1359,8 +1595,11 @@ def test_shutdown_cleans_up_resources(dist_init):
patch
.
object
(
nixl_wrapper
,
"deregister_memory"
)
as
mock_dereg
,
):
worker
.
_recving_transfers
=
{
"req1"
:
[
123
]}
worker
.
src_xfer_side_handle
=
456
worker
.
dst_xfer_side_handles
=
{
"engine1"
:
789
}
# Mock register_kv_cache which registers local handle
worker
.
src_xfer_handles_by_block_size
=
{
worker
.
block_size
:
455
}
# P TP = 2 * D TP case, we should register 2 local handles
worker
.
src_xfer_handles_by_tp_ratio
=
{
-
2
:
[
456
,
457
]}
worker
.
dst_xfer_side_handles
=
{
"engine1"
:
{
0
:
789
}}
worker
.
_remote_agents
=
{
"engine1"
:
{
0
:
"agent1"
}}
worker
.
_registered_descs
=
[
"desc1"
,
"desc2"
]
...
...
@@ -1381,8 +1620,10 @@ def test_shutdown_cleans_up_resources(dist_init):
mock_listener
.
join
.
assert_called_once
()
mock_rel_xfer
.
assert_called_once_with
(
123
)
assert
mock_rel_dlist
.
call_count
==
2
mock_rel_dlist
.
assert_any_call
(
456
)
# src handle
assert
mock_rel_dlist
.
call_count
==
4
mock_rel_dlist
.
assert_any_call
(
455
)
# src handle (whole region)
mock_rel_dlist
.
assert_any_call
(
456
)
# src handle (1st chunk)
mock_rel_dlist
.
assert_any_call
(
457
)
# src handle (2nd chunk)
mock_rel_dlist
.
assert_any_call
(
789
)
# dst handle
mock_rem_agent
.
assert_called_once_with
(
"agent1"
)
assert
mock_dereg
.
call_count
==
2
...
...
@@ -1394,7 +1635,7 @@ def test_shutdown_cleans_up_resources(dist_init):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
,
)
def
test_aborted_request_removed_from_worker_in_batch
(
dist_init
):
def
test_aborted_request_removed_from_worker_in_batch
(
default_vllm_config
,
dist_init
):
"""
Create and schedule a request so that P adds it to in-batch tracking via
the real scheduler, then simulate an abort (request not in next scheduler
...
...
@@ -1464,6 +1705,8 @@ class FailingNixlWrapper(FakeNixlWrapper):
self
.
fail_handshake
=
False
self
.
fail_transfer_setup
=
False
self
.
fail_send_notif
=
False
self
.
fail_transfer_state
=
False
# Returns "ERR" state
self
.
fail_transfer_exception
=
False
# Raises exception in check_xfer_state
def
add_remote_agent
(
self
,
agent_metadata
:
bytes
)
->
str
:
if
self
.
fail_handshake
:
...
...
@@ -1498,12 +1741,156 @@ class FailingNixlWrapper(FakeNixlWrapper):
raise
RuntimeError
(
"Simulated send_notif failure"
)
return
super
().
send_notif
(
agent_name
,
notif_msg
)
def check_xfer_state(self, handle: int) -> str:
    """Report the transfer state, optionally simulating a failure.

    Raises ``RuntimeError`` when ``fail_transfer_exception`` is set and
    returns ``"ERR"`` when ``fail_transfer_state`` is set. Otherwise the
    parent implementation decides.
    """
    if self.fail_transfer_exception:
        raise RuntimeError("Simulated check_xfer_state exception")
    # "ERR" stands for a bad transfer state.
    return "ERR" if self.fail_transfer_state else super().check_xfer_state(handle)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FailingNixlWrapper,
)
@pytest.mark.parametrize(
    "failure_type,wrapper_config,needs_get_finished",
    [
        ("transfer_setup_failed", {"fail_transfer_setup": True}, False),
        ("handshake_failed", {"fail_handshake": True}, False),
        ("notification_failed", {"fail_send_notif": True}, False),
        ("transfer_failed", {"fail_transfer_state": True}, True),
        ("transfer_exception", {"fail_transfer_exception": True}, True),
    ],
)
def test_transfer_failure_logging(
    default_vllm_config,
    dist_init,
    failure_type,
    wrapper_config,
    needs_get_finished,
):
    """Test that transfer failures are logged with structured context.

    Run with `pytest -sv` to see the log output.

    Covers failure types:
    - transfer_setup_failed: make_prepped_xfer fails
    - handshake_failed: add_remote_agent fails during request handshake
    - notification_failed: send_notif fails
    - transfer_failed: check_xfer_state returns bad state (e.g., "ERR")
    - transfer_exception: check_xfer_state raises exception

    Args:
        failure_type: Name of the simulated failure; expected in the logs.
        wrapper_config: Attribute name -> value pairs set on the worker's
            nixl wrapper to switch the failure on.
        needs_get_finished: Whether ``get_finished()`` must be called for the
            failure to surface.
    """
    import logging

    vllm_config = create_vllm_config()

    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        vllm_config, connector.engine_id, hand_shake_latency=0.0
    )

    # Configure FailingNixlWrapper to fail in the specified way
    for key, value in wrapper_config.items():
        setattr(connector.connector_worker.nixl_wrapper, key, value)

    request_id = f"test_{failure_type}_req"

    # For notification_failed, we need empty local blocks
    # (full cache hit path to trigger send_notif)
    local_blocks = [] if failure_type == "notification_failed" else [10, 11, 12]
    remote_blocks = [20, 21, 22]

    metadata = NixlConnectorMetadata()
    metadata.add_new_req_to_recv(
        request_id=request_id,
        local_block_ids=local_blocks,
        kv_transfer_params={
            "remote_block_ids": remote_blocks,
            "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            "remote_request_id": f"prefill-{request_id}",
            "remote_host": "localhost",
            "remote_port": 1234,
            "remote_tp_size": 1,
        },
    )
    connector.bind_connector_metadata(metadata)

    dummy_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )

    # Capture logs from the nixl_connector logger specifically
    # vLLM loggers have propagate=False, so we need to capture directly
    nixl_logger = logging.getLogger(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
    )
    captured_logs: list[logging.LogRecord] = []

    class LogCapture(logging.Handler):
        def emit(self, record):
            captured_logs.append(record)

    handler = LogCapture()
    handler.setLevel(logging.ERROR)
    nixl_logger.addHandler(handler)

    try:
        connector.start_load_kv(dummy_ctx)
        # Process the ready_requests queue (for async handshake)
        connector.bind_connector_metadata(NixlConnectorMetadata())
        # Wait for async handshake to complete
        time.sleep(0.2)
        connector.start_load_kv(dummy_ctx)

        # For transfer_failed/transfer_exception, the error happens in
        # get_finished() when checking transfer state
        if needs_get_finished:
            connector.get_finished(finished_req_ids=set())
    finally:
        # Always detach the handler so it does not leak into other tests.
        nixl_logger.removeHandler(handler)

    # Print logs for manual comparison between commits
    error_logs = [r for r in captured_logs if r.levelno >= logging.ERROR]
    # `LogRecord.message` only exists once a Formatter has processed the
    # record, and LogCapture never formats. `getMessage()` gives the same
    # text without depending on another handler being installed.
    all_messages = [r.getMessage() for r in error_logs]

    print("\n" + "=" * 60)
    print(f"CAPTURED ERROR LOGS for {failure_type}:")
    print("=" * 60)
    for i, message in enumerate(all_messages):
        print(f"\n--- Log {i + 1} ---")
        print(f"Message: {message}")
    print("=" * 60 + "\n")

    assert len(error_logs) >= 1, f"Expected at least one error log for {failure_type}"

    # Verify structured logging output (new format)
    # Check that at least one log matches the expected format
    combined_logs = "\n".join(all_messages)
    assert any("NIXL transfer failure" in msg for msg in all_messages), (
        f"Expected structured log format with 'NIXL transfer failure' prefix "
        f"for {failure_type}. Got: {all_messages}"
    )
    assert any("failure_type" in msg for msg in all_messages), (
        f"Expected 'failure_type' in logs. Got: {all_messages}"
    )
    assert any("Context:" in msg for msg in all_messages), (
        f"Expected 'Context:' in logs. Got: {all_messages}"
    )
    # Check that the expected failure_type appears in at least one log
    # Note: handshake_failed also triggers handshake_setup_failed
    assert failure_type in combined_logs or (
        failure_type == "handshake_failed" and "handshake_setup_failed" in combined_logs
    ), f"Expected '{failure_type}' in logs. Got: {all_messages}"
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FailingNixlWrapper
,
)
def
test_handshake_failure_returns_finished
(
dist_init
):
def
test_handshake_failure_returns_finished
(
default_vllm_config
,
dist_init
):
"""Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config
=
create_vllm_config
()
...
...
@@ -1552,7 +1939,7 @@ def test_handshake_failure_returns_finished(dist_init):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FailingNixlWrapper
,
)
def
test_transfer_setup_failure_returns_finished
(
dist_init
):
def
test_transfer_setup_failure_returns_finished
(
default_vllm_config
,
dist_init
):
"""Test that transfer setup failures mark blocks invalid
and return via get_finished."""
vllm_config
=
create_vllm_config
()
...
...
@@ -1627,6 +2014,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
FakeNixlWrapper
,
)
def
test_compatibility_hash_validation
(
default_vllm_config
,
dist_init
,
mismatch_type
,
config_overrides
,
...
...
@@ -1739,7 +2127,7 @@ def test_compatibility_hash_validation(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
,
)
def
test_handshake_decode_errors
(
dist_init
,
error_scenario
):
def
test_handshake_decode_errors
(
default_vllm_config
,
dist_init
,
error_scenario
):
"""
Test that msgspec decode errors are properly handled during handshake.
...
...
tests/v1/kv_connector/unit/test_offloading_connector.py
View file @
7e63ef82
...
...
@@ -26,6 +26,7 @@ from vllm.v1.core.kv_cache_utils import (
init_none_hash
,
)
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_offload.abstract
import
(
LoadStoreSpec
,
OffloadingEvent
,
...
...
@@ -64,8 +65,11 @@ class MockLoadStoreSpec(LoadStoreSpec):
class
MockOffloadingHandler
(
OffloadingHandler
):
def
__init__
(
self
):
self
.
transfer_specs
:
dict
[
int
,
TransferSpec
]
=
{}
self
.
completed_transfers
:
list
[
TransferResult
]
=
[]
self
.
completed_specs
:
list
[
TransferSpec
]
=
[]
self
.
waiting_jobs
:
set
[
int
]
=
set
()
self
.
completed_jobs
:
list
[
int
]
=
[]
self
.
flushed_jobs
:
set
[
int
]
=
set
()
def
get_finished
(
self
)
->
list
[
TransferResult
]:
finished
=
self
.
completed_transfers
...
...
@@ -73,14 +77,25 @@ class MockOffloadingHandler(OffloadingHandler):
return
finished
def
transfer_async
(
self
,
job_id
:
int
,
spec
:
TransferSpec
)
->
bool
:
self
.
completed_specs
.
append
(
spec
)
self
.
completed_transfers
.
append
((
job_id
,
True
)
)
self
.
transfer_specs
[
job_id
]
=
spec
self
.
waiting_jobs
.
add
(
job_id
)
return
True
def complete_jobs(self, job_ids: set[int]) -> None:
    """Mark every still-waiting job in ``job_ids`` as completed.

    Job ids that are not present in ``self.waiting_jobs`` are ignored.
    Each completed job is removed from ``self.waiting_jobs``, appended to
    ``self.completed_jobs`` and recorded in ``self.completed_transfers``
    as a ``(job_id, True)`` pair.
    """
    for pending_id in job_ids:
        # Guard clause: only jobs that are still waiting can be completed.
        if pending_id not in self.waiting_jobs:
            continue
        self.waiting_jobs.remove(pending_id)
        self.completed_jobs.append(pending_id)
        self.completed_transfers.append((pending_id, True))
def wait(self, job_ids: set[int]) -> None:
    """Record ``job_ids`` as flushed, then complete them.

    The ids are merged into ``self.flushed_jobs`` before being handed to
    ``self.complete_jobs``.
    """
    self.flushed_jobs.update(job_ids)
    self.complete_jobs(job_ids)
class
MockOffloadingSpec
(
OffloadingSpec
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
super
().
__init__
(
vllm_config
)
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
kv_cache_config
:
KVCacheConfig
):
super
().
__init__
(
vllm_config
,
kv_cache_config
)
self
.
manager
=
MagicMock
(
spec
=
OffloadingManager
)
self
.
manager
.
lookup
.
return_value
=
0
...
...
@@ -98,9 +113,22 @@ class MockOffloadingSpec(OffloadingSpec):
yield
GPULoadStoreSpec
,
MockLoadStoreSpec
,
self
.
handler
yield
MockLoadStoreSpec
,
GPULoadStoreSpec
,
self
.
handler
def complete_transfers(self):
    """Complete every job the handler currently has waiting."""
    handler = self.handler
    # Pass a snapshot rather than the live ``waiting_jobs`` set.
    pending_snapshot = handler.waiting_jobs.copy()
    handler.complete_jobs(pending_snapshot)
def
get_completed_transfers
(
self
)
->
list
[
TransferSpec
]:
specs
=
self
.
handler
.
completed_specs
self
.
handler
.
completed_specs
=
[]
specs
=
[
self
.
handler
.
transfer_specs
[
job_id
]
for
job_id
in
self
.
handler
.
completed_jobs
]
self
.
handler
.
completed_jobs
.
clear
()
return
specs
def get_flushed_transfers(self):
    """Return the transfer specs of all flushed jobs and reset the flushed set.

    Each id in ``self.handler.flushed_jobs`` is looked up in
    ``self.handler.transfer_specs``. The flushed set is emptied afterwards.
    """
    handler = self.handler
    flushed_specs = []
    for flushed_id in handler.flushed_jobs:
        flushed_specs.append(handler.transfer_specs[flushed_id])
    handler.flushed_jobs.clear()
    return flushed_specs
...
...
@@ -170,11 +198,9 @@ class RequestRunner:
# mapping (offloading address) -> gpu_block_index
self
.
offloaded
:
dict
[
Any
,
int
]
=
{}
self
.
pending_loads_count
:
int
=
0
self
.
pending_stores_count
:
int
=
0
self
.
completed_loads
:
list
[
TransferSummary
]
=
[]
self
.
completed_stores
:
list
[
TransferSummary
]
=
[]
self
.
flushed_gpu_block_indexes
:
set
[
int
]
=
set
()
# maps {block_id: block_offset}
self
.
gpu_block_index
:
dict
[
int
,
int
]
=
{}
...
...
@@ -201,54 +227,60 @@ class RequestRunner:
self
.
scheduler
.
add_request
(
req
)
def
_wait_for_transfers
(
self
):
def
_parse_transfers
(
self
):
for
transfer_spec
in
self
.
offloading_spec
.
get_flushed_transfers
():
src_spec
,
dst_spec
=
transfer_spec
assert
isinstance
(
src_spec
,
GPULoadStoreSpec
)
for
block_id
in
src_spec
.
block_ids
:
self
.
flushed_gpu_block_indexes
.
add
(
self
.
gpu_block_index
[
block_id
.
item
()]
)
block_size_factor
=
self
.
offloaded_block_size
//
self
.
gpu_block_size
while
self
.
pending_loads_count
or
self
.
pending_stores_count
:
for
transfer_spec
in
self
.
offloading_spec
.
get_completed_transfers
():
src_spec
,
dst_spec
=
transfer_spec
if
isinstance
(
src_spec
,
GPULoadStoreSpec
):
store
=
True
gpu_spec
=
src_spec
offload_spec
=
dst_spec
else
:
store
=
False
gpu_spec
=
dst_spec
offload_spec
=
src_spec
assert
isinstance
(
offload_spec
,
MockLoadStoreSpec
)
assert
isinstance
(
gpu_spec
,
GPULoadStoreSpec
)
gpu_block_indices
:
list
[
int
]
=
[]
for
block_id
in
gpu_spec
.
block_ids
:
gpu_block_indices
.
append
(
self
.
gpu_block_index
[
block_id
.
item
()])
# list of (block_hash, sub_block_offset)
offload_addresses
:
list
[
Any
]
=
[]
for
block_hash
in
offload_spec
.
block_hashes
:
for
sub_block_idx
in
range
(
block_size_factor
):
offload_addresses
.
append
((
block_hash
,
sub_block_idx
))
if
store
:
assert
len
(
gpu_block_indices
)
==
len
(
offload_addresses
)
self
.
completed_stores
.
append
(
TransferSummary
(
gpu_block_indices
,
offload_addresses
)
)
self
.
pending_stores_count
-=
1
else
:
remainder_sub_block_count
=
len
(
offload_addresses
)
-
len
(
gpu_block_indices
)
assert
remainder_sub_block_count
>=
0
assert
remainder_sub_block_count
<
block_size_factor
offload_addresses
=
offload_addresses
[
remainder_sub_block_count
:]
self
.
completed_loads
.
append
(
TransferSummary
(
gpu_block_indices
,
offload_addresses
)
)
self
.
pending_loads_count
-=
1
for
transfer_spec
in
self
.
offloading_spec
.
get_completed_transfers
():
src_spec
,
dst_spec
=
transfer_spec
if
isinstance
(
src_spec
,
GPULoadStoreSpec
):
store
=
True
gpu_spec
=
src_spec
offload_spec
=
dst_spec
else
:
store
=
False
gpu_spec
=
dst_spec
offload_spec
=
src_spec
assert
isinstance
(
offload_spec
,
MockLoadStoreSpec
)
assert
isinstance
(
gpu_spec
,
GPULoadStoreSpec
)
gpu_block_indices
:
list
[
int
]
=
[]
for
block_id
in
gpu_spec
.
block_ids
:
gpu_block_indices
.
append
(
self
.
gpu_block_index
[
block_id
.
item
()])
# list of (block_hash, sub_block_offset)
offload_addresses
:
list
[
Any
]
=
[]
for
block_hash
in
offload_spec
.
block_hashes
:
for
sub_block_idx
in
range
(
block_size_factor
):
offload_addresses
.
append
((
block_hash
,
sub_block_idx
))
if
store
:
assert
len
(
gpu_block_indices
)
==
len
(
offload_addresses
)
self
.
completed_stores
.
append
(
TransferSummary
(
gpu_block_indices
,
offload_addresses
)
)
else
:
remainder_sub_block_count
=
len
(
offload_addresses
)
-
len
(
gpu_block_indices
)
assert
remainder_sub_block_count
>=
0
assert
remainder_sub_block_count
<
block_size_factor
offload_addresses
=
offload_addresses
[
remainder_sub_block_count
:]
self
.
completed_loads
.
append
(
TransferSummary
(
gpu_block_indices
,
offload_addresses
)
)
def
_update_gpu_block_idx
(
self
):
for
blocks
in
self
.
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
...
...
@@ -257,18 +289,19 @@ class RequestRunner:
for
block_idx
,
block
in
enumerate
(
blocks
):
self
.
gpu_block_index
[
block
.
block_id
]
=
block_idx
def
_run
(
self
,
decoded_tokens
:
list
[
int
]):
def
_run
(
self
,
decoded_tokens
:
list
[
int
]
,
complete_transfers
:
bool
):
"""
Runs multiple engine (scheduler + worker) steps.
Assumes a single request is running.
Args:
decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
"""
tokens_iter
=
iter
(
decoded_tokens
)
token_id
=
next
(
tokens_iter
,
None
)
while
token_id
is
not
Non
e
:
while
Tru
e
:
assert
self
.
scheduler
.
requests
scheduler_output
=
self
.
scheduler
.
schedule
()
...
...
@@ -278,8 +311,10 @@ class RequestRunner:
assert
kv_connector_metadata
is
not
None
assert
isinstance
(
kv_connector_metadata
,
OffloadingConnectorMetadata
)
self
.
pending_loads_count
+=
len
(
kv_connector_metadata
.
reqs_to_load
)
self
.
pending_stores_count
+=
len
(
kv_connector_metadata
.
reqs_to_store
)
if
scheduler_output
.
preempted_req_ids
:
self
.
worker_connector
.
handle_preemptions
(
scheduler_output
.
preempted_req_ids
)
self
.
worker_connector
.
bind_connector_metadata
(
kv_connector_metadata
)
self
.
worker_connector
.
start_load_kv
(
self
.
_dummy_ctx
)
...
...
@@ -287,6 +322,9 @@ class RequestRunner:
if
scheduler_output
.
total_num_scheduled_tokens
>
0
:
self
.
worker_connector
.
wait_for_save
()
if
complete_transfers
:
self
.
offloading_spec
.
complete_transfers
()
finished_sending
,
finished_recving
=
self
.
worker_connector
.
get_finished
(
scheduler_output
.
finished_req_ids
)
...
...
@@ -297,7 +335,7 @@ class RequestRunner:
reqs
=
self
.
scheduler
.
running
,
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
token_id
=
token_id
,
token_id
=
token_id
or
0
,
)
if
self
.
scheduler
.
running
:
...
...
@@ -305,7 +343,10 @@ class RequestRunner:
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
self
.
_wait_for_transfers
()
if
token_id
is
None
:
break
self
.
_parse_transfers
()
# run one more step to update finished stored
if
EOS_TOKEN_ID
in
decoded_tokens
:
...
...
@@ -330,8 +371,10 @@ class RequestRunner:
def
run
(
self
,
decoded_tokens
:
list
[
int
],
complete_transfers
:
bool
=
True
,
expected_stored_gpu_block_indexes
:
tuple
[
int
,
...]
=
(),
expected_loaded_gpu_block_indexes
:
tuple
[
int
,
...]
=
(),
expected_flushed_gpu_block_indexes
:
tuple
[
int
,
...]
=
(),
):
"""
Runs multiple engine (scheduler + worker) steps.
...
...
@@ -339,14 +382,17 @@ class RequestRunner:
Args:
decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
expected_stored_gpu_block_indexes: GPU block indexes
that are expected to be written during the run.
expected_loaded_gpu_block_indexes: GPU block indexes
that are expected to be loaded during the run.
expected_flushed_gpu_block_indexes: GPU block indexes
that are expected to be flushed during the run.
"""
self
.
manager
.
reset_mock
()
self
.
_run
(
decoded_tokens
)
self
.
_run
(
decoded_tokens
,
complete_transfers
)
loaded_gpu_block_indexes
:
set
[
int
]
=
set
()
for
transfer
in
self
.
completed_loads
:
...
...
@@ -370,6 +416,9 @@ class RequestRunner:
assert
set
(
expected_stored_gpu_block_indexes
)
==
stored_gpu_block_indexes
self
.
completed_stores
.
clear
()
assert
set
(
expected_flushed_gpu_block_indexes
)
==
self
.
flushed_gpu_block_indexes
self
.
flushed_gpu_block_indexes
.
clear
()
@
pytest
.
fixture
def
request_runner
():
...
...
@@ -414,10 +463,13 @@ def test_offloading_connector(request_runner):
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
block_hashes
:
generate_store_output
(
list
(
block_hashes
)[
1
:
2
])
)
runner
.
run
(
decoded_tokens
=
[
0
]
,
expected_stored_gpu_block_indexes
=
(
3
,
4
,
5
)
)
runner
.
run
(
decoded_tokens
=
[
0
])
# add block missing 1 token -> no offload
runner
.
run
(
decoded_tokens
=
[
0
]
*
(
offloaded_block_size
-
1
))
runner
.
run
(
decoded_tokens
=
[
0
]
*
(
offloaded_block_size
-
1
),
expected_stored_gpu_block_indexes
=
(
3
,
4
,
5
),
)
runner
.
manager
.
prepare_store
.
assert_not_called
()
# +1 token -> single block, fail prepare_store
...
...
@@ -435,23 +487,20 @@ def test_offloading_connector(request_runner):
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
block_hashes
:
generate_store_output
(
block_hashes
)
)
runner
.
run
(
decoded_tokens
=
[
0
]
*
offloaded_block_size
,
expected_stored_gpu_block_indexes
=
(
15
,
16
,
17
),
)
runner
.
run
(
decoded_tokens
=
[
0
]
*
offloaded_block_size
)
runner
.
manager
.
touch
.
assert_called
()
block_hashes1
=
list
(
runner
.
manager
.
touch
.
call_args
.
args
[
0
])
assert
len
(
block_hashes1
)
==
6
# terminate request
runner
.
run
(
decoded_tokens
=
[
EOS_TOKEN_ID
])
runner
.
run
(
decoded_tokens
=
[
EOS_TOKEN_ID
],
expected_stored_gpu_block_indexes
=
(
15
,
16
,
17
),
)
# create a new request differing only on the last token
runner
.
new_request
(
token_ids
=
[
0
]
*
(
offloaded_block_size
*
6
-
1
)
+
[
1
])
runner
.
run
(
decoded_tokens
=
[
0
],
expected_stored_gpu_block_indexes
=
tuple
(
range
(
6
*
block_size_factor
)),
)
runner
.
run
(
decoded_tokens
=
[
0
])
runner
.
manager
.
touch
.
assert_called
()
block_hashes2
=
list
(
runner
.
manager
.
touch
.
call_args
.
args
[
0
])
assert
len
(
block_hashes2
)
==
6
...
...
@@ -461,7 +510,10 @@ def test_offloading_connector(request_runner):
assert
block_hashes1
[
5
]
!=
block_hashes2
[
5
]
# terminate request
runner
.
run
(
decoded_tokens
=
[
EOS_TOKEN_ID
])
runner
.
run
(
decoded_tokens
=
[
EOS_TOKEN_ID
],
expected_stored_gpu_block_indexes
=
tuple
(
range
(
6
*
block_size_factor
)),
)
# full_block_tokens - num_computed_tokens < offloaded_block_size
runner
.
new_request
(
...
...
@@ -528,7 +580,74 @@ def test_offloading_connector(request_runner):
assert
event
.
token_ids
==
[]
assert
event
.
parent_block_hash
is
None
assert
event
.
lora_id
is
None
assert
event
.
lora_name
is
None
event
=
events
[
1
]
assert
isinstance
(
event
,
BlockRemoved
)
assert
event
.
block_hashes
==
to_hashes
([
4
,
5
,
6
])
assert
event
.
medium
==
"B"
def test_request_preemption(request_runner):
    """Exercise a request that is preempted and later resumed.

    Stores are issued without completing transfers, the free-block count is
    forced to zero to trigger preemption (which is expected to flush the
    pending stores), and the request is then resumed by re-loading the
    offloaded blocks.
    """
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100

    def store_everything(block_hashes):
        # Accept every block handed to prepare_store.
        return generate_store_output(block_hashes)

    # GPU block indexes 0..8, i.e. the first three offloaded blocks.
    first_nine_gpu_blocks = (0, 1, 2, 3, 4, 5, 6, 7, 8)

    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
    )

    free_block_queue = runner.scheduler.kv_cache_manager.block_pool.free_block_queue
    num_free_blocks_empty = free_block_queue.num_free_blocks

    # 2 blocks, store all, without flushing
    # blocks = [0, 1, 2], [3, 4, 5]
    runner.new_request(token_ids=[0] * (2 * offloaded_block_size))
    runner.manager.prepare_store.side_effect = store_everything
    runner.run(decoded_tokens=[0], complete_transfers=False)

    # decode 2 more blocks - 1 gpu block, storing [6, 7, 8] (no flush)
    runner.manager.prepare_store.side_effect = store_everything
    runner.run(
        decoded_tokens=[0] * (2 * offloaded_block_size - gpu_block_size),
        complete_transfers=False,
    )

    # simulate KV cache running out of space
    free_block_queue.num_free_blocks = 0

    # request should be preempted now
    runner.run(
        decoded_tokens=[],
        complete_transfers=False,
        expected_flushed_gpu_block_indexes=first_nine_gpu_blocks,
        expected_stored_gpu_block_indexes=first_nine_gpu_blocks,
    )

    # restore KV cache space and reset GPU prefix cache
    free_block_queue.num_free_blocks = num_free_blocks_empty
    runner.scheduler.reset_prefix_cache()

    # request should now return from preemption
    # re-load [0, ..., 8] from the CPU and store [9, 10, 11]
    runner.manager.lookup.return_value = 3
    runner.manager.prepare_store.side_effect = store_everything
    runner.run(
        decoded_tokens=[0] * gpu_block_size,
        expected_loaded_gpu_block_indexes=first_nine_gpu_blocks,
    )

    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(9, 10, 11),
    )
tests/v1/kv_connector/unit/utils.py
View file @
7e63ef82
...
...
@@ -11,6 +11,7 @@ import torch
from
vllm
import
SamplingParams
from
vllm.config
import
(
AttentionConfig
,
CacheConfig
,
DeviceConfig
,
KVTransferConfig
,
...
...
@@ -94,6 +95,7 @@ def create_vllm_config(
dtype
:
str
=
"float16"
,
cache_dtype
:
str
=
"auto"
,
hf_overrides
:
dict
[
str
,
Any
]
|
None
=
None
,
attention_backend
:
str
|
None
=
None
,
)
->
VllmConfig
:
"""Initialize VllmConfig For Testing."""
model_config
=
ModelConfig
(
...
...
@@ -131,12 +133,14 @@ def create_vllm_config(
enable_permute_local_kv
=
enable_permute_local_kv
,
kv_connector_extra_config
=
kv_connector_extra_config
or
{},
)
attention_config
=
AttentionConfig
(
backend
=
attention_backend
)
return
VllmConfig
(
scheduler_config
=
scheduler_config
,
model_config
=
model_config
,
cache_config
=
cache_config
,
kv_transfer_config
=
kv_transfer_config
,
device_config
=
DeviceConfig
(
"cpu"
),
attention_config
=
attention_config
,
)
...
...
@@ -151,7 +155,13 @@ def create_scheduler(
kv_cache_tensors
=
[],
kv_cache_groups
=
[
KVCacheGroupSpec
(
[
"layer"
],
FullAttentionSpec
(
block_size
,
1
,
1
,
torch
.
float32
,
False
)
[
"layer"
],
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
1
,
head_size
=
1
,
dtype
=
torch
.
float32
,
),
)
],
)
...
...
tests/v1/kv_offload/test_cpu_gpu.py
View file @
7e63ef82
...
...
@@ -7,6 +7,7 @@ import pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.kv_offload.mediums
import
CPULoadStoreSpec
,
GPULoadStoreSpec
from
vllm.v1.kv_offload.worker.cpu_gpu
import
CpuGpuOffloadingHandlers
...
...
@@ -49,6 +50,7 @@ NUM_MAPPINGS = [3]
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_transfer
(
default_vllm_config
,
gpu_to_cpu
:
bool
,
num_mappings
:
int
,
head_size
:
int
,
...
...
@@ -62,7 +64,7 @@ def test_transfer(
seed
:
int
,
device
:
str
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
set_random_seed
(
seed
)
# create per-layer GPU KV caches based on available attn_backends
attn_backends_list
=
BACKENDS_TO_TEST
...
...
tests/v1/kv_offload/test_cpu_offloading.py
View file @
7e63ef82
...
...
@@ -13,13 +13,12 @@ from vllm import LLM, SamplingParams, TokensPrompt
from
vllm.config
import
KVEventsConfig
,
KVTransferConfig
from
vllm.distributed.kv_events
import
BlockStored
,
KVEventBatch
from
vllm.platforms
import
current_platform
from
vllm.utils.system_utils
import
set_env_var
CPU_BLOCK_SIZES
=
[
48
]
ATTN_BACKENDS
=
[
"FLASH_ATTN"
]
ATTN_BACKENDS
=
[]
if
current_platform
.
is_cuda
():
ATTN_BACKENDS
.
append
(
"FLASHINFER"
)
ATTN_BACKENDS
=
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"TRITON_ATTN"
]
elif
current_platform
.
is_rocm
():
ATTN_BACKENDS
=
[
"TRITON_ATTN"
]
...
...
@@ -162,7 +161,7 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
kv_connector
=
"OffloadingConnector"
,
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
"
num_
cpu_b
locks
"
:
1
000
,
"cpu_b
ytes_to_use
"
:
5
00
<<
2
0
,
"block_size"
:
cpu_block_size
,
},
)
...
...
@@ -180,13 +179,13 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
topic
=
"test"
,
)
with
set_env_var
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
):
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
gpu_memory_utilization
=
0.5
,
kv_events
_config
=
kv_
events
_config
,
kv_transfer_config
=
kv_transfer_config
,
)
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
gpu_memory_utilization
=
0.5
,
kv_events_config
=
kv_events_config
,
kv_transfer
_config
=
kv_
transfer
_config
,
attention_config
=
{
"backend"
:
attn_backend
}
,
)
events_endpoint
=
events_endpoint
.
replace
(
"*"
,
"127.0.0.1"
)
subscriber
=
MockSubscriber
(
events_endpoint
,
topic
=
kv_events_config
.
topic
)
...
...
tests/v1/kv_offload/test_worker.py
View file @
7e63ef82
...
...
@@ -63,6 +63,12 @@ class OffloadingHandler1To2(OffloadingHandler):
del
self
.
transfers
[
job_id
]
return
finished
def wait(self, job_ids: set[int]) -> None:
    """Assert that every tracked job in ``job_ids`` has already finished.

    Ids with no (truthy) entry in ``self.transfers`` are skipped.
    """
    for waited_id in job_ids:
        tracked = self.transfers.get(waited_id)
        if not tracked:
            continue
        assert tracked.finished
class
OffloadingHandler2To1
(
OffloadingHandler
):
def
__init__
(
self
):
...
...
@@ -84,6 +90,12 @@ class OffloadingHandler2To1(OffloadingHandler):
del
self
.
transfers
[
job_id
]
return
finished
def wait(self, job_ids: set[int]) -> None:
    """Assert that every tracked job in ``job_ids`` has already finished.

    Ids with no (truthy) entry in ``self.transfers`` are skipped.
    """
    for waited_id in job_ids:
        tracked = self.transfers.get(waited_id)
        if not tracked:
            continue
        assert tracked.finished
def
test_offloading_worker
():
"""
...
...
tests/v1/metrics/test_perf_metrics.py
0 → 100644
View file @
7e63ef82
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the analytic performance estimators in vllm.v1.metrics.perf.
"""
import
types
from
types
import
SimpleNamespace
from
transformers.models.deepseek_v3.configuration_deepseek_v3
import
DeepseekV3Config
from
transformers.models.llama4.configuration_llama4
import
(
Llama4Config
,
Llama4TextConfig
,
)
from
transformers.models.qwen3.configuration_qwen3
import
Qwen3Config
from
transformers.models.qwen3_moe.configuration_qwen3_moe
import
Qwen3MoeConfig
from
vllm.config.model
import
ModelConfig
,
get_hf_text_config
from
vllm.transformers_utils.model_arch_config_convertor
import
(
MODEL_ARCH_CONFIG_CONVERTORS
,
ModelArchConfigConvertorBase
,
)
from
vllm.v1.metrics.perf
import
(
AttentionMetrics
,
BaseConfigParser
,
ExecutionContext
,
FfnMetrics
,
ModelMetrics
,
ParsedArgs
,
UnembedMetrics
,
)
class MockModelConfig:
    """Lightweight stand-in for ``ModelConfig`` used by the parser tests.

    Only the attributes assigned in ``__init__`` live on the instance. Any
    other attribute lookup is delegated to the real ``ModelConfig`` class
    through ``__getattr__``.
    """

    def __init__(self, hf_config, dtype):
        text_config = get_hf_text_config(hf_config)
        # Use the base convertor when the model type has no registry entry.
        convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get(
            hf_config.model_type, ModelArchConfigConvertorBase
        )

        self.hf_config = hf_config
        self.hf_text_config = text_config
        self.model_arch_config = convertor_cls(hf_config, text_config).convert()
        self.dtype = dtype
        self.is_attention_free = False

    def __getattr__(self, name):
        """Resolve an attribute missing on the mock against ``ModelConfig``.

        Properties are evaluated with this mock as the instance, plain
        functions are returned bound to this mock, and any other class
        attribute is returned unchanged.

        Raises:
            AttributeError: if ``ModelConfig`` has no attribute ``name``.
        """
        not_found = object()
        class_attr = getattr(ModelConfig, name, not_found)
        if class_attr is not_found:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}' "
                f"and neither does 'ModelConfig'."
            )

        # Properties and plain functions both follow the descriptor protocol:
        # a property runs its getter against this mock, a function comes back
        # as a method bound to this mock.
        if isinstance(class_attr, (property, types.FunctionType)):
            return class_attr.__get__(self, type(self))

        # Anything else is a plain class-level attribute.
        return class_attr
def create_mock_vllm_config(
    hf_config,
    model_dtype="bfloat16",
    cache_dtype="auto",
    quant_config=None,
    data_parallel_size=1,
    tensor_parallel_size=1,
    pipeline_parallel_size=1,
    enable_expert_parallel=False,
) -> SimpleNamespace:
    """Build a ``SimpleNamespace`` that stands in for a vLLM config.

    The namespace carries ``model_config`` (a ``MockModelConfig``),
    ``cache_config``, ``quant_config`` and ``parallel_config``, populated
    from the given arguments.
    """
    return SimpleNamespace(
        model_config=MockModelConfig(hf_config, model_dtype),
        cache_config=SimpleNamespace(cache_dtype=cache_dtype),
        quant_config=quant_config,
        parallel_config=SimpleNamespace(
            data_parallel_size=data_parallel_size,
            tensor_parallel_size=tensor_parallel_size,
            pipeline_parallel_size=pipeline_parallel_size,
            enable_expert_parallel=enable_expert_parallel,
        ),
    )
#### Parser Tests ####
def test_base_config_parser():
    """BaseConfigParser should copy the basic model dimensions and byte sizes."""
    model_dims = {
        "vocab_size": 50000,
        "hidden_size": 2048,
        "num_attention_heads": 16,
        "num_hidden_layers": 24,
    }
    vllm_config = create_mock_vllm_config(
        Qwen3Config(**model_dims), model_dtype="float16"
    )

    parsed = BaseConfigParser().parse(ParsedArgs(), vllm_config)

    # Every dimension passed to the HF config must be reflected in the result.
    for field_name, expected_value in model_dims.items():
        assert getattr(parsed, field_name) == expected_value
    assert parsed.weight_byte_size == 2  # float16 is 2 bytes
    assert parsed.activation_byte_size == 2  # default activation size
def test_base_attention_config_parser_with_gqa():
    """The attention parser chain should report explicit GQA settings."""
    kv_heads = 8  # GQA with 4:1 ratio against 32 attention heads
    head_dim = 128
    vllm_config = create_mock_vllm_config(
        Qwen3Config(
            hidden_size=4096,
            num_attention_heads=32,
            num_key_value_heads=kv_heads,
            head_dim=head_dim,
        )
    )

    parsed = AttentionMetrics.get_parser().parse(vllm_config)

    assert parsed.num_key_value_heads == kv_heads
    assert parsed.head_dim == head_dim
def test_base_attention_config_parser_without_gqa():
    """
    The attention parser chain should fall back to MHA when
    num_key_value_heads is not given.
    """
    attention_heads = 32
    # num_key_value_heads is deliberately left out of the HF config.
    vllm_config = create_mock_vllm_config(
        Qwen3Config(hidden_size=4096, num_attention_heads=attention_heads)
    )

    parsed = AttentionMetrics.get_parser().parse(vllm_config)

    # Should default to MHA (num_key_value_heads = num_attention_heads)
    assert parsed.num_key_value_heads == attention_heads
def test_base_ffn_config_parser_dense():
    """The FFN parser chain should report no MoE fields for a dense model."""
    intermediate_size = 11008
    vllm_config = create_mock_vllm_config(
        Qwen3Config(
            hidden_size=4096,
            intermediate_size=intermediate_size,
            num_hidden_layers=32,
        )
    )

    parsed = FfnMetrics.get_parser().parse(vllm_config)

    assert parsed.intermediate_size == intermediate_size
    # A dense model has no experts and therefore no MoE layers.
    assert parsed.num_experts == 0
    assert parsed.num_experts_per_tok == 0
    assert parsed.num_moe_layers == 0
def test_base_ffn_config_parser_moe():
    """The FFN parser chain should pick up the MoE settings of a MoE model."""
    num_layers = 32
    vllm_config = create_mock_vllm_config(
        Qwen3MoeConfig(
            hidden_size=4096,
            intermediate_size=11008,
            num_hidden_layers=num_layers,
            num_experts=64,
            num_experts_per_tok=8,
            moe_intermediate_size=14336,
            n_shared_experts=2,
        )
    )

    parsed = FfnMetrics.get_parser().parse(vllm_config)

    assert parsed.num_experts == 64
    assert parsed.num_experts_per_tok == 8
    assert parsed.moe_intermediate_size == 14336
    assert parsed.num_shared_experts == 2
    # All layers are MoE by default
    assert parsed.num_moe_layers == num_layers
def test_interleave_moe_layer_step_parser():
    """interleave_moe_layer_step should determine the parsed MoE layer count."""
    text_config = Llama4TextConfig(
        num_hidden_layers=32,
        num_local_experts=64,
        interleave_moe_layer_step=4,  # Every 4th layer is MoE
    )
    vllm_config = create_mock_vllm_config(Llama4Config(text_config=text_config))

    parsed = FfnMetrics.get_parser().parse(vllm_config)

    # 32 layers with every 4th one being MoE -> 8 MoE layers.
    assert parsed.num_moe_layers == 8
def test_moe_layer_freq_parser():
    """moe_layer_freq and first_k_dense_replace should determine the MoE layer count."""
    vllm_config = create_mock_vllm_config(
        DeepseekV3Config(
            num_hidden_layers=30,
            n_routed_experts=64,
            moe_layer_freq=3,  # Every 3rd layer after first_k_dense_replace
            first_k_dense_replace=6,  # First 6 layers are dense
        )
    )

    parsed = FfnMetrics.get_parser().parse(vllm_config)

    # Layers >= 6 and divisible by 3: 6, 9, 12, 15, 18, 21, 24, 27
    expected_moe_layers = sum(
        1 for layer_idx in range(30) if layer_idx >= 6 and layer_idx % 3 == 0
    )
    assert expected_moe_layers == 8
    assert parsed.num_moe_layers == expected_moe_layers
#### ComponentMetrics Tests ####
def test_attention_metrics_scaling():
    """Doubling the layer count should double every attention metric."""
    # Settings shared by both configs; only num_hidden_layers differs.
    shared_dims = {
        "hidden_size": 2048,
        "num_attention_heads": 16,
        "num_key_value_heads": 16,
        "head_dim": 128,
    }

    base_metrics = AttentionMetrics.from_vllm_config(
        create_mock_vllm_config(Qwen3Config(num_hidden_layers=12, **shared_dims))
    )
    # Same model with twice as many layers.
    doubled_metrics = AttentionMetrics.from_vllm_config(
        create_mock_vllm_config(Qwen3Config(num_hidden_layers=24, **shared_dims))
    )

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    # FLOPs, read bytes and write bytes should all double with the layers.
    for getter_name in ("get_num_flops", "get_read_bytes", "get_write_bytes"):
        base_value = getattr(base_metrics, getter_name)(ctx)
        doubled_value = getattr(doubled_metrics, getter_name)(ctx)
        assert doubled_value == 2 * base_value
def test_attention_metrics_grouped_query():
    """GQA should read fewer bytes than MHA for the same decode step."""

    def single_layer_config(kv_heads):
        # Identical single-layer models that differ only in KV head count.
        return create_mock_vllm_config(
            Qwen3Config(
                hidden_size=4096,
                num_attention_heads=32,
                num_key_value_heads=kv_heads,
                num_hidden_layers=1,
            )
        )

    mha_config = single_layer_config(32)  # MHA
    gqa_config = single_layer_config(8)  # GQA with 4:1 ratio

    mha_metrics = AttentionMetrics.from_vllm_config(mha_config)
    gqa_metrics = AttentionMetrics.from_vllm_config(gqa_config)

    decode_ctx = ExecutionContext.from_single_request(
        num_tokens=1, context_len=1024, is_prefill=False
    )

    # GQA should have less KV cache reads since fewer KV heads
    mha_read_bytes = mha_metrics.get_read_bytes(decode_ctx)
    gqa_read_bytes = gqa_metrics.get_read_bytes(decode_ctx)
    assert gqa_read_bytes < mha_read_bytes
def test_ffn_metrics_scaling():
    """Doubling the intermediate size should double the FFN FLOPs."""

    def ffn_metrics_for(intermediate_size):
        # Build metrics for a model that differs only in intermediate size.
        hf_config = Qwen3Config(
            hidden_size=2048,
            intermediate_size=intermediate_size,
            num_hidden_layers=12,
        )
        return FfnMetrics.from_vllm_config(create_mock_vllm_config(hf_config))

    base_metrics = ffn_metrics_for(8192)
    wider_metrics = ffn_metrics_for(16384)  # Double intermediate size

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    # FLOPS should double when intermediate size doubles
    base_flops = base_metrics.get_num_flops(ctx)
    wider_flops = wider_metrics.get_num_flops(ctx)
    assert wider_flops == base_flops * 2
def test_moe_metrics_vs_dense():
    """A MoE FFN with two routed experts should cost twice the dense FLOPs."""
    # Dimensions shared by the dense and the MoE model.
    common_dims = {
        "hidden_size": 2048,
        "intermediate_size": 8192,
        "num_hidden_layers": 12,
    }

    dense_config = create_mock_vllm_config(Qwen3Config(**common_dims))
    moe_config = create_mock_vllm_config(
        Qwen3MoeConfig(
            num_experts=64,
            num_experts_per_tok=2,  # 2 routed experts
            moe_intermediate_size=8192,
            n_shared_experts=0,
            **common_dims,
        )
    )

    dense_metrics = FfnMetrics.from_vllm_config(dense_config)
    moe_metrics = FfnMetrics.from_vllm_config(moe_config)

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    # MoE should have different compute/memory characteristics
    dense_flops = dense_metrics.get_num_flops(ctx)
    moe_flops = moe_metrics.get_num_flops(ctx)

    # 2 routed experts vs 1 dense.
    assert moe_flops == dense_flops * 2
def test_unembed_metrics_scaling():
    """Doubling the vocabulary size should double the unembedding FLOPs."""

    def config_for_vocab(vocab_size):
        # Models that differ only in vocabulary size.
        return create_mock_vllm_config(
            Qwen3Config(hidden_size=2048, vocab_size=vocab_size)
        )

    small_vocab_config = config_for_vocab(32000)
    large_vocab_config = config_for_vocab(64000)  # Double vocab size

    small_vocab_metrics = UnembedMetrics.from_vllm_config(small_vocab_config)
    large_vocab_metrics = UnembedMetrics.from_vllm_config(large_vocab_config)

    ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    # FLOPS should double when vocab size doubles
    small_vocab_flops = small_vocab_metrics.get_num_flops(ctx)
    large_vocab_flops = large_vocab_metrics.get_num_flops(ctx)
    assert large_vocab_flops == 2 * small_vocab_flops
def test_prefill_vs_decode_differences():
    """Prefill and decode contexts should yield different attention read bytes."""
    attention_metrics = AttentionMetrics.from_vllm_config(
        create_mock_vllm_config(
            Qwen3Config(
                hidden_size=2048,
                num_attention_heads=16,
                num_key_value_heads=16,
                num_hidden_layers=1,
            )
        )
    )

    # Same context length; prefill processes 512 tokens, decode a single one.
    prefill_ctx = ExecutionContext.from_single_request(
        num_tokens=512, context_len=512, is_prefill=True
    )
    decode_ctx = ExecutionContext.from_single_request(
        num_tokens=1, context_len=512, is_prefill=False
    )

    prefill_read_bytes = attention_metrics.get_read_bytes(prefill_ctx)
    decode_read_bytes = attention_metrics.get_read_bytes(decode_ctx)

    assert prefill_read_bytes != decode_read_bytes
def test_model_metrics_aggregation():
    """The FLOP breakdown reported by ModelMetrics must add up to its total."""
    dense_hf = Qwen3Config(
        hidden_size=2048,
        num_attention_heads=16,
        num_hidden_layers=12,
        vocab_size=32000,
        intermediate_size=8192,
    )
    aggregate = ModelMetrics(create_mock_vllm_config(dense_hf))
    request_ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    flops_total = aggregate.get_num_flops(request_ctx)
    flops_by_component = aggregate.get_num_flops_breakdown(request_ctx)

    # Every component's contribution must be accounted for in the total.
    assert flops_total == sum(flops_by_component.values())
def test_moe_expert_activation_proportional_scaling():
    """Routed-expert cost must grow linearly with ``num_experts_per_tok``.

    Three MoE configs that differ only in ``num_experts_per_tok`` (1, 2, 3) are
    compared. For FLOPs, read bytes and write bytes alike, going from 1 to 3
    experts must cost exactly twice as much extra as going from 1 to 2.
    """

    def ffn_metrics_for(experts_per_tok):
        # Everything except the number of routed experts per token is fixed,
        # including the two shared experts.
        moe_hf = Qwen3MoeConfig(
            hidden_size=2048,
            intermediate_size=8192,
            num_hidden_layers=12,
            num_experts=64,
            num_experts_per_tok=experts_per_tok,
            moe_intermediate_size=8192,
            n_shared_experts=2,
        )
        return FfnMetrics.from_vllm_config(create_mock_vllm_config(moe_hf))

    single, double, triple = (ffn_metrics_for(k) for k in (1, 2, 3))
    request_ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    for metric_getter in ("get_num_flops", "get_read_bytes", "get_write_bytes"):
        base_cost = getattr(single, metric_getter)(request_ctx)
        double_cost = getattr(double, metric_getter)(request_ctx)
        triple_cost = getattr(triple, metric_getter)(request_ctx)

        cost_of_one_extra_expert = double_cost - base_cost
        cost_of_two_extra_experts = triple_cost - base_cost

        # Proportional scaling: two extra experts cost twice one extra expert.
        assert cost_of_two_extra_experts == 2 * cost_of_one_extra_expert
def test_quantization_config_parser_fp8():
    """Attention and FFN parsers must report 1-byte weights for fp8."""

    class MockQuantConfig:
        # Minimal stand-in: the test only supplies a quantization name.
        def get_name(self):
            return "fp8"

    quantized_cfg = create_mock_vllm_config(
        Qwen3Config(hidden_size=2048, num_attention_heads=16, num_hidden_layers=1),
        quant_config=MockQuantConfig(),
    )

    for metrics_cls in (AttentionMetrics, FfnMetrics):
        parsed = metrics_cls.get_parser().parse(quantized_cfg)
        assert parsed.weight_byte_size == 1  # fp8
def test_quantization_config_parser_mxfp4():
    """The FFN parser must report half-byte weights for mxfp4."""

    class MockQuantConfig:
        # Minimal stand-in: the test only supplies a quantization name.
        def get_name(self):
            return "mxfp4"

    quantized_cfg = create_mock_vllm_config(
        Qwen3Config(hidden_size=2048, intermediate_size=8192, num_hidden_layers=1),
        quant_config=MockQuantConfig(),
    )

    parsed_ffn = FfnMetrics.get_parser().parse(quantized_cfg)
    assert parsed_ffn.weight_byte_size == 0.5  # mxfp4
#### Per-GPU Tests ####
def test_attention_per_gpu_with_tensor_parallelism():
    """Compare global and per-GPU attention metrics under tensor parallelism.

    With TP=4 the global FLOP count must be exactly four times the per-GPU
    count; read and write traffic only need to be strictly larger globally
    than on a single GPU.
    """
    tp_size = 4
    attn_hf = Qwen3Config(
        hidden_size=4096,
        num_attention_heads=32,
        num_key_value_heads=8,
        num_hidden_layers=24,
    )
    attn = AttentionMetrics.from_vllm_config(
        create_mock_vllm_config(attn_hf, tensor_parallel_size=tp_size)
    )
    request_ctx = ExecutionContext.from_single_request(
        num_tokens=128, context_len=1024, is_prefill=True
    )

    flops_all_gpus = attn.get_num_flops(request_ctx, per_gpu=False)
    flops_one_gpu = attn.get_num_flops(request_ctx, per_gpu=True)
    assert flops_all_gpus == tp_size * flops_one_gpu

    # Reads first, then writes: only a strict ordering is required here.
    for byte_counter in (attn.get_read_bytes, attn.get_write_bytes):
        bytes_all_gpus = byte_counter(request_ctx, per_gpu=False)
        bytes_one_gpu = byte_counter(request_ctx, per_gpu=True)
        assert bytes_all_gpus > bytes_one_gpu
def test_attention_per_gpu_with_pipeline_parallelism():
    """Compare global and per-GPU attention metrics under pipeline parallelism.

    With PP=4, both FLOPs and read bytes must be exactly four times larger
    globally than per GPU for a decode-phase context.
    """
    pp_size = 4
    attn_hf = Qwen3Config(
        hidden_size=2048,
        num_attention_heads=16,
        num_hidden_layers=32,
    )
    attn = AttentionMetrics.from_vllm_config(
        create_mock_vllm_config(attn_hf, pipeline_parallel_size=pp_size)
    )
    request_ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=False
    )

    for metric_getter in (attn.get_num_flops, attn.get_read_bytes):
        value_all_gpus = metric_getter(request_ctx, per_gpu=False)
        value_one_gpu = metric_getter(request_ctx, per_gpu=True)
        assert value_all_gpus == pp_size * value_one_gpu
def test_ffn_per_gpu_with_tensor_parallelism():
    """Compare global and per-GPU FFN FLOPs under DP=2 combined with TP=4.

    The test expects ``ffn_tp_size`` to be 8 for this configuration (expert
    parallelism is not requested) and global FLOPs to be eight times the
    per-GPU FLOPs.
    """
    expected_ffn_tp = 8
    ffn_hf = Qwen3Config(
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
    )
    ffn = FfnMetrics.from_vllm_config(
        create_mock_vllm_config(
            ffn_hf,
            data_parallel_size=2,
            tensor_parallel_size=4,
        )
    )
    assert ffn.ffn_tp_size == expected_ffn_tp

    request_ctx = ExecutionContext.from_single_request(
        num_tokens=128, context_len=2048, is_prefill=True
    )
    flops_all_gpus = ffn.get_num_flops(request_ctx, per_gpu=False)
    flops_one_gpu = ffn.get_num_flops(request_ctx, per_gpu=True)

    assert flops_all_gpus == expected_ffn_tp * flops_one_gpu
def test_ffn_per_gpu_with_pipeline_parallelism():
    """With PP=6, global FFN FLOPs must be six times the per-GPU FLOPs."""
    pp_size = 6
    ffn_hf = Qwen3Config(
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
    )
    ffn = FfnMetrics.from_vllm_config(
        create_mock_vllm_config(ffn_hf, pipeline_parallel_size=pp_size)
    )
    request_ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    flops_all_gpus = ffn.get_num_flops(request_ctx, per_gpu=False)
    flops_one_gpu = ffn.get_num_flops(request_ctx, per_gpu=True)

    assert flops_all_gpus == pp_size * flops_one_gpu
def test_moe_per_gpu_with_expert_parallelism():
    """
    Check MoE read-byte accounting with expert parallelism (DP=2, TP=4, EP on).

    The test expects ``ffn_ep_size == 8`` and ``ffn_tp_size == 1`` for this
    configuration, and requires per-GPU routed-expert weight reads to be
    strictly smaller than the global ones.
    """
    moe_hf = Qwen3MoeConfig(
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
        num_experts=64,
        num_experts_per_tok=8,
        moe_intermediate_size=14336,
        n_shared_experts=2,
    )
    ffn = FfnMetrics.from_vllm_config(
        create_mock_vllm_config(
            moe_hf,
            data_parallel_size=2,
            tensor_parallel_size=4,
            enable_expert_parallel=True,
        )
    )
    assert ffn.ffn_ep_size == 8
    assert ffn.ffn_tp_size == 1

    request_ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )
    reads_one_gpu = ffn.get_read_bytes_breakdown(request_ctx, per_gpu=True)
    reads_all_gpus = ffn.get_read_bytes_breakdown(request_ctx, per_gpu=False)

    weights_key = "routed_up_gate_weights"
    # NOTE(review): when the key is missing the checks below are skipped and
    # the test still passes — confirm that this leniency is intended.
    if weights_key not in reads_one_gpu:
        return

    weights_one_gpu = reads_one_gpu[weights_key]
    weights_all_gpus = reads_all_gpus[weights_key]

    # The global view covers more experts than a single GPU's share.
    assert weights_one_gpu < weights_all_gpus
    # The exact ratio is not pinned down; it only has to exceed one.
    assert weights_all_gpus / weights_one_gpu > 1
def test_moe_per_gpu_expert_activation_accounting():
    """
    Check per-GPU MoE read accounting for a small and a large batch (EP=8).

    With 64 experts spread over 8 GPUs and no shared experts, the test expects
    the routed weight reads to be identical for 10 and 1000 tokens, while the
    routed input reads must scale by exactly 100x (1000 / 10).
    """
    moe_hf = Qwen3MoeConfig(
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=12,
        num_experts=64,
        num_experts_per_tok=8,
        moe_intermediate_size=14336,
        n_shared_experts=0,  # keep shared experts out of the picture
    )
    ffn = FfnMetrics.from_vllm_config(
        create_mock_vllm_config(
            moe_hf,
            data_parallel_size=8,
            enable_expert_parallel=True,
        )
    )

    def per_gpu_reads(num_tokens):
        batch_ctx = ExecutionContext.from_single_request(
            num_tokens=num_tokens, context_len=512, is_prefill=True
        )
        return ffn.get_read_bytes_breakdown(batch_ctx, per_gpu=True)

    reads_small_batch = per_gpu_reads(10)
    reads_large_batch = per_gpu_reads(1000)

    weights_key = "routed_up_gate_weights"
    # NOTE(review): when the key is missing the checks below are skipped and
    # the test still passes — confirm that this leniency is intended.
    if weights_key not in reads_small_batch:
        return

    # Both batch sizes are expected to touch the same set of expert weights.
    assert reads_small_batch[weights_key] == reads_large_batch[weights_key]

    # Activation reads, in contrast, follow the token count: 1000/10 = 100x.
    input_key = "routed_up_gate_input"
    assert reads_large_batch[input_key] == 100 * reads_small_batch[input_key]
def test_unembed_per_gpu_with_tensor_parallelism():
    """Compare global and per-GPU unembedding metrics under TP=8.

    Global FLOPs and weight reads must be eight times the per-GPU values,
    whereas the input reads must be identical in both views.
    """
    tp_size = 8
    unembed = UnembedMetrics.from_vllm_config(
        create_mock_vllm_config(
            Qwen3Config(hidden_size=4096, vocab_size=128000),
            tensor_parallel_size=tp_size,
        )
    )
    request_ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    flops_all_gpus = unembed.get_num_flops(request_ctx, per_gpu=False)
    flops_one_gpu = unembed.get_num_flops(request_ctx, per_gpu=True)
    assert flops_all_gpus == tp_size * flops_one_gpu

    reads_all_gpus = unembed.get_read_bytes_breakdown(request_ctx, per_gpu=False)
    reads_one_gpu = unembed.get_read_bytes_breakdown(request_ctx, per_gpu=True)

    # Input reads do not shrink per GPU; weight reads do, by the TP factor.
    assert reads_all_gpus["input"] == reads_one_gpu["input"]
    assert reads_all_gpus["weight"] == tp_size * reads_one_gpu["weight"]
def test_model_metrics_per_gpu_aggregation():
    """ModelMetrics must aggregate consistently in per-GPU and global mode.

    Under TP=2 and PP=2, each mode's FLOP breakdown must sum to that mode's
    total, and the global total must exceed the per-GPU total.
    """
    dense_hf = Qwen3Config(
        hidden_size=2048,
        num_attention_heads=16,
        num_hidden_layers=12,
        vocab_size=32000,
        intermediate_size=8192,
    )
    aggregate = ModelMetrics(
        create_mock_vllm_config(
            dense_hf,
            tensor_parallel_size=2,
            pipeline_parallel_size=2,
        )
    )
    request_ctx = ExecutionContext.from_single_request(
        num_tokens=100, context_len=512, is_prefill=True
    )

    totals = {}
    for per_gpu in (True, False):
        parts = aggregate.get_num_flops_breakdown(request_ctx, per_gpu=per_gpu)
        totals[per_gpu] = aggregate.get_num_flops(request_ctx, per_gpu=per_gpu)
        # The breakdown has to account for the whole total in either mode.
        assert totals[per_gpu] == sum(parts.values())

    total_one_gpu = totals[True]
    total_all_gpus = totals[False]
    assert total_all_gpus > total_one_gpu

    # The exact ratio depends on which parallelism applies to which component,
    # so it is only required to exceed one.
    assert total_all_gpus / total_one_gpu > 1
def test_attention_per_gpu_heads_not_evenly_divisible():
    """Attention metrics must cope with head counts that TP does not divide.

    17 query heads and 5 KV heads are used with TP=4. The computation must not
    fail, both FLOP counts must be positive, and the global count must exceed
    the per-GPU count.
    """
    odd_heads_hf = Qwen3Config(
        hidden_size=2048,
        num_attention_heads=17,  # 17 % 4 != 0
        num_key_value_heads=5,  # 5 % 4 != 0
        num_hidden_layers=8,
    )
    attn = AttentionMetrics.from_vllm_config(
        create_mock_vllm_config(odd_heads_hf, tensor_parallel_size=4)
    )
    request_ctx = ExecutionContext.from_single_request(
        num_tokens=64, context_len=256, is_prefill=True
    )

    flops_one_gpu = attn.get_num_flops(request_ctx, per_gpu=True)
    flops_all_gpus = attn.get_num_flops(request_ctx, per_gpu=False)

    assert flops_one_gpu > 0
    assert flops_all_gpus > 0
    assert flops_all_gpus > flops_one_gpu
tests/v1/sample/test_logprobs.py
View file @
7e63ef82
...
...
@@ -516,6 +516,424 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
del
llm
class TestCorrectDecodedToken:
    """Unit tests for ``LogprobsProcessor._correct_decoded_token``.

    The method under test deals with tokens whose decoded text ends with the
    Unicode replacement character "�" (U+FFFD), which shows up when a
    multi-byte UTF-8 character is split across tokens (byte-fallback
    tokenization). All tests drive it through a mocked tokenizer.
    """

    @staticmethod
    def _build_processor(tokenizer, logprobs):
        """Construct a LogprobsProcessor around *tokenizer* and *logprobs*."""
        from vllm.v1.engine.logprobs import LogprobsProcessor

        return LogprobsProcessor(
            tokenizer=tokenizer,
            logprobs=logprobs,
            prompt_logprobs=None,
            cumulative_logprob=0.0,
            num_logprobs=1,
            num_prompt_logprobs=None,
        )

    @pytest.fixture
    def mock_tokenizer(self):
        """Provide a bare ``Mock`` standing in for the tokenizer."""
        from unittest.mock import Mock

        return Mock()

    @pytest.fixture
    def processor_with_empty_logprobs(self, mock_tokenizer):
        """Provide a processor that has no logprobs recorded yet."""
        return self._build_processor(mock_tokenizer, [])

    @pytest.fixture
    def processor_with_previous_logprobs(self, mock_tokenizer):
        """Provide a processor whose logprobs already hold token ID 123."""
        return self._build_processor(mock_tokenizer, [{123: None}])

    def test_correction_with_previous_token_in_list(
        self, processor_with_empty_logprobs
    ):
        """A token is fixed by decoding it together with its list predecessor.

        Only ``decode([101, 102])`` yields clean text; every other call yields
        the replacement character.
        """
        proc = processor_with_empty_logprobs
        token_ids = [100, 101, 102]

        def fake_decode(ids):
            if ids == [101, 102]:
                return "valid"
            return "�"

        proc.tokenizer.decode.side_effect = fake_decode

        corrected = proc._correct_decoded_token(2, token_ids)

        assert corrected == "valid"
        proc.tokenizer.decode.assert_called_with([101, 102])

    def test_correction_with_previous_logprob_token(
        self, processor_with_previous_logprobs
    ):
        """A lone token is fixed using the token from the previous logprobs.

        The list has a single entry, so the only helpful context is token 123
        stored in the processor's logprobs.
        """
        proc = processor_with_previous_logprobs
        token_ids = [100]

        def fake_decode(ids):
            # Only the pair (previous logprob token, current token) decodes.
            if ids == [123, 100]:
                return ' "polarized"'
            return "�"

        proc.tokenizer.decode.side_effect = fake_decode

        corrected = proc._correct_decoded_token(0, token_ids)

        assert corrected == ' "polarized"'

    def test_correction_at_idx_zero_no_previous_logprobs(
        self, processor_with_empty_logprobs
    ):
        """Without any context, the first token falls back to an empty string."""
        proc = processor_with_empty_logprobs
        token_ids = [100]
        # Every decode attempt yields the replacement character.
        proc.tokenizer.decode.return_value = "�"

        corrected = proc._correct_decoded_token(0, token_ids)

        assert corrected == ""

    def test_correction_at_idx_zero_with_previous_logprobs(
        self, processor_with_previous_logprobs
    ):
        """The first token is fixed via the previous logprob token (ID 123)."""
        proc = processor_with_previous_logprobs
        token_ids = [200]

        def fake_decode(ids):
            if ids == [123, 200]:
                return "corrected"
            return "�"

        proc.tokenizer.decode.side_effect = fake_decode

        corrected = proc._correct_decoded_token(0, token_ids)

        assert corrected == "corrected"

    def test_no_correction_needed_returns_fallback(
        self, processor_with_previous_logprobs
    ):
        """If every attempt still ends with "�", an empty string is returned."""
        proc = processor_with_previous_logprobs
        token_ids = [100, 101, 102]
        # No amount of context helps: the text always ends with "�".
        proc.tokenizer.decode.return_value = "still�"

        corrected = proc._correct_decoded_token(2, token_ids)

        assert corrected == ""

    def test_middle_token_correction(self, processor_with_previous_logprobs):
        """A token in the middle of a longer list (idx=5) can be corrected."""
        proc = processor_with_previous_logprobs
        token_ids = [10, 20, 30, 40, 50, 60, 70, 80]

        def fake_decode(ids):
            # Only the token at idx=5 paired with its predecessor decodes.
            if ids == [50, 60]:
                return "olar"
            return "�"

        proc.tokenizer.decode.side_effect = fake_decode

        corrected = proc._correct_decoded_token(5, token_ids)

        assert corrected == "olar"

    def test_multiple_consecutive_replacement_chars(
        self, processor_with_previous_logprobs
    ):
        """Consecutive uncorrectable tokens each become an empty string."""
        proc = processor_with_previous_logprobs
        token_ids = [100, 101, 102]
        proc.tokenizer.decode.return_value = "still�"

        # First, then second replacement character of the sequence.
        for position in (0, 1):
            assert proc._correct_decoded_token(position, token_ids) == ""

    def test_correction_with_multibyte_utf8(self, processor_with_previous_logprobs):
        """Both halves of a split multi-byte sequence are corrected.

        The token at idx=0 is expected to be decoded together with the previous
        logprob token (123), the token at idx=1 together with its predecessor
        in the list.
        Example from user: "�", "�" -> "", "\""
        """
        proc = processor_with_previous_logprobs
        token_ids = [200, 201]

        # (token IDs passed to decode, text returned); first match wins.
        scripted_decodes = (
            ([123, 200], ' "'),  # idx=0 with previous logprob token
            ([200, 201], '"'),  # idx=1 with previous token in list
            ([123, 200, 201], ' ""'),  # full sequence
        )

        def fake_decode(ids):
            for expected_ids, text in scripted_decodes:
                if ids == expected_ids:
                    return text
            return "�"

        proc.tokenizer.decode.side_effect = fake_decode

        first = proc._correct_decoded_token(0, token_ids)
        assert first == ' "'

        second = proc._correct_decoded_token(1, token_ids)
        assert second == '"'

    def test_real_world_opt125m_scenario(self, mock_tokenizer):
        """Replay the decode pattern of a user report against facebook/opt-125m.

        Before: [" the", " term", " �", "�", "p", "olar", "ized", "�", "�", ...]
        After: [" the", " term", "", " "", "p", "olar", "ized", "", "\"", ...]
        """
        proc = self._build_processor(mock_tokenizer, [])
        token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9]  # placeholder IDs

        # Tokens that decode to "�" when decoded on their own.
        broken_alone = (3, 4, 8, 9)
        # Multi-token decodes: some are repaired, some still end with "�".
        scripted_decodes = (
            ([2, 3], " term�"),
            ([3, 4], ' "'),
            ([7, 8], "ized�"),
            ([8, 9], '"'),
            ([1, 2, 3], " the term�"),
            ([2, 3, 4], ' term "'),
        )

        def fake_decode(ids):
            if len(ids) == 1:
                if ids[0] in broken_alone:
                    return "�"
                return "normal_text"
            for expected_ids, text in scripted_decodes:
                if ids == expected_ids:
                    return text
            return "normal_text"

        mock_tokenizer.decode.side_effect = fake_decode

        # idx=2: decode([2, 3]) still ends with "�" and there are no previous
        # logprobs to fall back on, so the result is the empty string.
        assert proc._correct_decoded_token(2, token_ids) == ""

        # idx=3: decode([3, 4]) yields clean text.
        proc.logprobs = [{2: None}]  # pretend one token was already processed
        assert proc._correct_decoded_token(3, token_ids) == ' "'
def test_verify_tokens_integration():
    """Integration test for replacement-character handling with a real model.

    Generates a short completion with facebook/opt-125m and checks that no
    decoded token in the returned logprobs ends with, or consists solely of,
    the Unicode replacement character "�".

    The engine is released in a ``finally`` block so that its GPU memory does
    not stay allocated for the tests that run afterwards, even when an
    assertion fails.
    """
    runner = VllmRunner(
        "facebook/opt-125m",
        max_logprobs=0,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    )
    try:
        # Prompt taken from the user report that motivated this test.
        test_prompts = ["In this example,"]

        sampling_params = SamplingParams(
            max_tokens=16,
            temperature=0,
            logprobs=0,
        )

        results = runner.llm.generate(test_prompts, sampling_params=sampling_params)

        for result in results:
            assert result.outputs[0].logprobs is not None
            for logprob_dict in result.outputs[0].logprobs:
                for token_id, logprob_info in logprob_dict.items():
                    decoded_token = logprob_info.decoded_token
                    # A token is either corrected or replaced by "".
                    assert not decoded_token.endswith("�"), (
                        f"Token {token_id} decoded to '{decoded_token}' which "
                        f"ends with replacement character"
                    )
                    assert decoded_token != "�", (
                        f"Token {token_id} is a lone replacement character"
                    )
    finally:
        # Same teardown pattern as the spec-decode test in this file.
        del runner
        cleanup_dist_env_and_memory()
def test_utf8_edge_cases_with_real_model():
    """Test various UTF-8 edge cases with a real model.

    Runs facebook/opt-125m on prompts containing multi-byte UTF-8 characters
    and checks that no decoded token in the returned logprobs ends with the
    replacement character "�".

    The engine is released in a ``finally`` block so that its GPU memory does
    not stay allocated for the tests that run afterwards, even when an
    assertion fails.
    """
    runner = VllmRunner(
        "facebook/opt-125m",
        max_logprobs=1,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    )
    try:
        test_prompts = [
            'Smart quotes: "Hello"',  # Curly quotes
            "Em dash — test",  # Em dash
            "Ellipsis… continues",  # Ellipsis
            "Chinese: 你好",  # Chinese characters
            "Emoji: 😀 🎉",  # Emojis
            'Mixed: "quoted" — with symbols',  # Mixed
        ]

        sampling_params = SamplingParams(
            max_tokens=10,
            temperature=0,
            logprobs=1,
        )

        results = runner.llm.generate(test_prompts, sampling_params=sampling_params)

        # Pair each result with its prompt so failures name the input.
        for prompt, result in zip(test_prompts, results):
            assert result.outputs[0].logprobs is not None

            for logprob_dict in result.outputs[0].logprobs:
                for token_id, logprob_info in logprob_dict.items():
                    decoded_token = logprob_info.decoded_token
                    assert not decoded_token.endswith("�"), (
                        f"Prompt: '{prompt}'\n"
                        f"Token {token_id} decoded to '{decoded_token}' which "
                        f"ends with replacement character"
                    )
    finally:
        # Same teardown pattern as the spec-decode test in this file.
        del runner
        cleanup_dist_env_and_memory()
def test_correct_decoded_token_preserves_valid_tokens():
    """Test that decoded tokens for a plain ASCII prompt stay clean.

    Runs facebook/opt-125m on an ASCII-only prompt and checks that every
    decoded token in the returned logprobs is a ``str`` that does not contain
    the replacement character "�".

    The engine is released in a ``finally`` block so that its GPU memory does
    not stay allocated for the tests that run afterwards, even when an
    assertion fails.
    """
    runner = VllmRunner(
        "facebook/opt-125m",
        max_logprobs=2,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    )
    try:
        # Simple prompt with standard ASCII characters
        test_prompts = ["Hello world, this is a test."]

        sampling_params = SamplingParams(
            max_tokens=10,
            temperature=0,
            logprobs=2,
        )

        results = runner.llm.generate(test_prompts, sampling_params=sampling_params)

        for result in results:
            assert result.outputs[0].logprobs is not None

            for logprob_dict in result.outputs[0].logprobs:
                for token_id, logprob_info in logprob_dict.items():
                    decoded_token = logprob_info.decoded_token
                    # An empty string is acceptable (a corrected token).
                    assert isinstance(decoded_token, str)
                    assert "�" not in decoded_token
    finally:
        # Same teardown pattern as the spec-decode test in this file.
        del runner
        cleanup_dist_env_and_memory()
@
pytest
.
mark
.
parametrize
(
"logprobs_mode"
,
get_args
(
LogprobsMode
))
@
pytest
.
mark
.
parametrize
(
"model_setup"
,
...
...
@@ -524,32 +942,74 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
(
"eagle"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
"nm-testing/Llama3_2_1B_speculator.eagle3"
,
{
"method"
:
"eagle"
,
"model"
:
"nm-testing/Llama3_2_1B_speculator.eagle3"
,
"num_speculative_tokens"
:
3
,
},
0
,
),
marks
=
large_gpu_mark
(
min_gb
=
32
),
id
=
"eagle0"
,
),
pytest
.
param
(
(
"eagle"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
{
"method"
:
"eagle"
,
"model"
:
"nm-testing/Llama3_2_1B_speculator.eagle3"
,
"num_speculative_tokens"
:
3
,
},
3
,
),
marks
=
large_gpu_mark
(
min_gb
=
32
),
id
=
"eagle3"
,
),
pytest
.
param
(
(
"ngram"
,
"meta-llama/Llama-3.2-1B-Instruct"
,
{
"method"
:
"ngram"
,
"prompt_lookup_max"
:
5
,
"prompt_lookup_min"
:
3
,
"num_speculative_tokens"
:
3
,
},
3
,
),
marks
=
large_gpu_mark
(
min_gb
=
32
),
id
=
"ngram"
,
),
],
)
@
pytest
.
mark
.
parametrize
(
"top_logprobs"
,
[
0
,
3
])
def
test_spec_decode_logprobs
(
logprobs_mode
:
LogprobsMode
,
model_setup
:
tuple
[
str
,
str
,
str
],
top_logprobs
:
int
,
model_setup
:
tuple
[
str
,
str
,
dict
,
int
],
):
"""Spec decode logprobs should match those of the base model.
Args:
logprobs_mode: logprobs mode.
model_setup:
Spec decode
method, base model name,
and
draft model name
.
model_setup:
Tuple of (
method, base model name,
speculative_config dict, top_logprobs)
.
"""
from
vllm
import
LLM
method
,
model_name
,
spec_config
,
top_logprobs
=
model_setup
prompt
=
"Hello world "
*
50
sampling_params
=
SamplingParams
(
temperature
=
0
,
logprobs
=
top_logprobs
,
max_tokens
=
10
,
ignore_eos
=
False
)
method
,
model_name
,
spec_model_name
=
model_setup
penalty_sampling_params
=
SamplingParams
(
temperature
=
0
,
logprobs
=
top_logprobs
,
max_tokens
=
10
,
ignore_eos
=
False
,
presence_penalty
=-
1.0
,
)
max_model_len
=
256
# Run base LLM.
...
...
@@ -560,27 +1020,27 @@ def test_spec_decode_logprobs(
seed
=
42
,
logprobs_mode
=
logprobs_mode
,
gpu_memory_utilization
=
0.4
,
enable_prefix_caching
=
False
,
)
ref_results
=
ref_llm
.
generate
(
[
prompt
,
prompt
],
[
sampling_params
,
penalty_sampling_params
]
)
ref_results
=
ref_llm
.
generate
([
prompt
],
sampling_params
)
# Collect logprobs outputs from reference LLM.
ref_logprobs
=
[]
for
output
in
ref_results
[
0
].
outputs
:
for
logprobs
in
output
.
logprob
s
:
for
token_id
in
logprobs
:
ref_logprobs
.
app
end
(
logprobs
[
token_id
]
)
for
results
in
ref_results
:
for
output
in
results
.
output
s
:
for
logprobs
in
output
.
logprobs
:
ref_logprobs
.
ext
end
(
logprobs
.
values
()
)
del
ref_llm
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
# Run spec decode LLM.
# Add max_model_len to spec_config if not present
spec_config_with_len
=
{
**
spec_config
,
"max_model_len"
:
max_model_len
}
spec_llm
=
LLM
(
model_name
,
speculative_config
=
{
"method"
:
method
,
"model"
:
spec_model_name
,
"num_speculative_tokens"
:
3
,
"max_model_len"
:
max_model_len
,
},
speculative_config
=
spec_config_with_len
,
max_logprobs
=
5
,
max_model_len
=
max_model_len
,
seed
=
42
,
...
...
@@ -589,14 +1049,17 @@ def test_spec_decode_logprobs(
# Force prefill chunking
enable_chunked_prefill
=
True
,
max_num_batched_tokens
=
32
,
enable_prefix_caching
=
False
,
)
spec_results
=
spec_llm
.
generate
(
[
prompt
,
prompt
],
[
sampling_params
,
penalty_sampling_params
]
)
spec_results
=
spec_llm
.
generate
([
prompt
],
sampling_params
)
# Collect logprobs outputs from spec decode LLM.
spec_logprobs
=
[]
for
output
in
spec_results
[
0
].
outputs
:
for
logprobs
in
output
.
logprob
s
:
for
token_id
in
logprobs
:
spec_logprobs
.
app
end
(
logprobs
[
token_id
]
)
for
results
in
spec_results
:
for
output
in
results
.
output
s
:
for
logprobs
in
output
.
logprobs
:
spec_logprobs
.
ext
end
(
logprobs
.
values
()
)
del
spec_llm
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
...
...
tests/v1/sample/test_rejection_sampler.py
View file @
7e63ef82
...
...
@@ -691,9 +691,13 @@ def test_frequency_penalties(rejection_sampler):
def
test_bad_words
(
rejection_sampler
):
"""Test rejection sampling with bad words constraints"""
"""Test rejection sampling with bad words constraints.
This test applies bad words to non-consecutive requests (0 and 2, but not 1)
to verify correct logit indexing when iterating over requests with bad words.
"""
spec_tokens
=
[[
1
,
2
,
3
],
[
1
,
15
,
3
],
[
1
,
2
,
3
]]
output_tokens
=
[[
1
,
2
,
3
,
4
],
[
1
,
2
,
3
,
4
],
[
1
,
2
,
3
,
4
]]
output_tokens
=
[[
1
,
2
,
3
,
4
],
[
1
,
15
,
3
,
4
],
[
1
,
2
,
3
,
4
]]
logits
=
create_logits_tensor
(
output_tokens
,
token_idx_to_override
=
15
)
metadata
=
create_sampling_metadata
(
...
...
@@ -701,17 +705,9 @@ def test_bad_words(rejection_sampler):
output_token_ids
=
[[
2
],
[
3
],
[
4
]],
spec_token_ids
=
spec_tokens
,
bad_words_token_ids
=
{
0
:
[
[
2
,
]
],
1
:
[
[
2
,
]
],
# Do not apply bad words to the last request
0
:
[[
2
]],
# Request 1 has no bad words (to test non-consecutive request handling)
2
:
[[
2
]],
},
)
bonus_token_tensor
=
torch
.
tensor
(
...
...
@@ -726,8 +722,11 @@ def test_bad_words(rejection_sampler):
sampling_metadata
=
metadata
,
)
# Request 0: bad word [2] matches prefix, so token 2 is rejected -> 15
# Request 1: no bad words, all tokens match -> [1, 15, 3, 4]
# Request 2: bad word [2] matches prefix, so token 2 is rejected -> 15
expected
=
torch
.
tensor
(
[[
1
,
15
,
-
1
,
-
1
],
[
1
,
15
,
3
,
4
],
[
1
,
2
,
3
,
4
]],
[[
1
,
15
,
-
1
,
-
1
],
[
1
,
15
,
3
,
4
],
[
1
,
15
,
-
1
,
-
1
]],
dtype
=
torch
.
int
,
device
=
logits
.
device
,
)
...
...
tests/v1/spec_decode/test_eagle.py
View file @
7e63ef82
...
...
@@ -14,8 +14,8 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec
,
try_get_attention_backend
,
)
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.config
import
(
AttentionConfig
,
CacheConfig
,
DeviceConfig
,
ModelConfig
,
...
...
@@ -27,6 +27,7 @@ from vllm.config import (
from
vllm.config.load
import
LoadConfig
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
...
...
@@ -41,6 +42,7 @@ eagle3_dir = os.path.join(models_path_prefix, "yuhuili/EAGLE3-LLaMA3.1-Instruct-
def
_create_proposer
(
method
:
str
,
num_speculative_tokens
:
int
,
attention_backend
:
str
|
None
=
None
,
speculative_token_tree
:
list
[
tuple
[
int
,
...]]
|
None
=
None
,
)
->
EagleProposer
:
model_config
=
ModelConfig
(
model
=
model_dir
,
runner
=
"generate"
,
max_model_len
=
100
)
...
...
@@ -73,6 +75,7 @@ def _create_proposer(
max_model_len
=
model_config
.
max_model_len
,
is_encoder_decoder
=
model_config
.
is_encoder_decoder
,
),
attention_config
=
AttentionConfig
(
backend
=
attention_backend
),
)
return
EagleProposer
(
vllm_config
=
vllm_config
,
device
=
current_platform
.
device_type
)
...
...
@@ -306,10 +309,16 @@ def test_prepare_inputs_padded():
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
)
output_metadata
,
token_indices_to_sample
=
proposer
.
prepare_inputs_padded
(
common_attn_metadata
,
spec_decode_metadata
,
valid_sampled_tokens_count
output_metadata
,
token_indices_to_sample
,
num_rejected_tokens_gpu
=
(
proposer
.
prepare_inputs_padded
(
common_attn_metadata
,
spec_decode_metadata
,
valid_sampled_tokens_count
)
)
# Verify num_rejected_tokens_gpu is calculated correctly
expected_num_rejected
=
torch
.
tensor
([
1
,
0
,
2
],
dtype
=
torch
.
int32
,
device
=
device
)
assert
torch
.
equal
(
num_rejected_tokens_gpu
,
expected_num_rejected
)
assert
output_metadata
.
max_query_len
==
3
assert
torch
.
equal
(
output_metadata
.
query_start_loc
,
expected_query_start_loc
)
assert
torch
.
equal
(
token_indices_to_sample
,
expected_token_indices_to_sample
)
...
...
@@ -334,8 +343,6 @@ def test_load_model(
use_distinct_lm_head
,
monkeypatch
,
):
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
pytest
.
skip
(
"TRITON_ATTN does not support "
...
...
@@ -399,7 +406,9 @@ def test_load_model(
assert
not
isinstance
(
target_model
,
SupportsMultiModal
)
# Create proposer using the helper function
proposer
=
_create_proposer
(
method
,
num_speculative_tokens
=
8
)
proposer
=
_create_proposer
(
method
,
num_speculative_tokens
=
8
,
attention_backend
=
attn_backend
)
# Call the method under test
proposer
.
load_model
(
target_model
)
...
...
@@ -425,8 +434,6 @@ def test_load_model(
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
get_attn_backend_list_based_on_platform
())
@
pytest
.
mark
.
parametrize
(
"num_speculative_tokens"
,
[
1
,
3
,
8
])
def
test_propose
(
method
,
attn_backend
,
num_speculative_tokens
,
monkeypatch
):
monkeypatch
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
pytest
.
skip
(
"TRITON_ATTN does not support "
...
...
@@ -454,7 +461,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
seq_lens
=
[
seq_len_1
,
seq_len_2
]
# Create proposer first so we can use its actual hidden_size
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
)
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
,
attention_backend
=
attn_backend
)
# Get the hidden_size from the proposer to ensure consistency
hidden_size
=
proposer
.
hidden_size
...
...
@@ -627,7 +636,9 @@ def test_propose_tree(spec_token_tree):
# Create proposer first so we can use its actual hidden_size.
proposer
=
_create_proposer
(
"eagle"
,
num_speculative_tokens
,
speculative_token_tree
=
spec_token_tree
"eagle"
,
num_speculative_tokens
,
speculative_token_tree
=
spec_token_tree
,
)
# Get the hidden_size from the proposer to ensure consistency.
hidden_size
=
proposer
.
hidden_size
...
...
tests/v1/spec_decode/test_mtp.py
View file @
7e63ef82
...
...
@@ -13,7 +13,6 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec
,
try_get_attention_backend
,
)
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
...
...
@@ -26,6 +25,7 @@ from vllm.config import (
from
vllm.config.load
import
LoadConfig
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
...utils
import
models_path_prefix
...
...
tests/v1/spec_decode/test_ngram.py
View file @
7e63ef82
...
...
@@ -85,10 +85,8 @@ def test_ngram_proposer():
token_ids_cpu
=
np
.
array
([[
1
,
2
,
3
,
4
,
5
]])
result
=
get_ngram_proposer
(
min_n
=
2
,
max_n
=
2
,
k
=
2
).
propose
(
sampled_token_ids
=
[[
0
]],
req_ids
=
[
"0"
],
num_tokens_no_spec
=
np
.
array
([
len
(
c
)
for
c
in
token_ids_cpu
]),
token_ids_cpu
=
token_ids_cpu
,
spec_decode_unsupported_reqs
=
(),
)
assert
len
(
result
[
0
])
==
0
...
...
@@ -96,10 +94,8 @@ def test_ngram_proposer():
token_ids_cpu
=
np
.
array
([[
1
,
2
,
3
,
4
,
1
,
2
,
3
]])
result
=
get_ngram_proposer
(
min_n
=
4
,
max_n
=
4
,
k
=
2
).
propose
(
sampled_token_ids
=
[[
0
]],
req_ids
=
[
"0"
],
num_tokens_no_spec
=
np
.
array
([
len
(
c
)
for
c
in
token_ids_cpu
]),
token_ids_cpu
=
token_ids_cpu
,
spec_decode_unsupported_reqs
=
(),
)
assert
len
(
result
[
0
])
==
0
...
...
@@ -107,10 +103,8 @@ def test_ngram_proposer():
token_ids_cpu
=
np
.
array
([[
1
,
2
,
3
,
4
,
1
,
2
,
3
]])
result
=
get_ngram_proposer
(
min_n
=
3
,
max_n
=
4
,
k
=
2
).
propose
(
sampled_token_ids
=
[[
0
]],
req_ids
=
[
"0"
],
num_tokens_no_spec
=
np
.
array
([
len
(
c
)
for
c
in
token_ids_cpu
]),
token_ids_cpu
=
token_ids_cpu
,
spec_decode_unsupported_reqs
=
(),
)
assert
np
.
array_equal
(
result
,
np
.
array
([[
4
,
1
]]))
...
...
@@ -119,10 +113,8 @@ def test_ngram_proposer():
token_ids_cpu
=
np
.
array
([[
2
,
3
,
4
,
5
,
1
,
2
,
3
,
4
,
1
,
2
,
3
,
4
]])
result
=
get_ngram_proposer
(
min_n
=
3
,
max_n
=
4
,
k
=
2
).
propose
(
sampled_token_ids
=
[[
0
]],
req_ids
=
[
"0"
],
num_tokens_no_spec
=
np
.
array
([
len
(
c
)
for
c
in
token_ids_cpu
]),
token_ids_cpu
=
token_ids_cpu
,
spec_decode_unsupported_reqs
=
(),
)
assert
np
.
array_equal
(
result
,
np
.
array
([[
1
,
2
]]))
# Not [5, 1]]
...
...
@@ -130,10 +122,8 @@ def test_ngram_proposer():
token_ids_cpu
=
np
.
array
([[
3
,
4
,
5
,
2
,
3
,
4
,
1
,
2
,
3
,
4
]])
result
=
get_ngram_proposer
(
min_n
=
2
,
max_n
=
4
,
k
=
2
).
propose
(
sampled_token_ids
=
[[
0
]],
req_ids
=
[
"0"
],
num_tokens_no_spec
=
np
.
array
([
len
(
c
)
for
c
in
token_ids_cpu
]),
token_ids_cpu
=
token_ids_cpu
,
spec_decode_unsupported_reqs
=
(),
)
assert
np
.
array_equal
(
result
,
np
.
array
([[
1
,
2
]]))
# Not [5, 2]]
...
...
@@ -141,10 +131,8 @@ def test_ngram_proposer():
token_ids_cpu
=
np
.
array
([[
1
,
2
,
3
,
100
,
1
,
2
,
3
,
200
,
1
,
2
,
3
,
300
,
1
,
2
,
3
]])
result
=
get_ngram_proposer
(
min_n
=
3
,
max_n
=
3
,
k
=
2
).
propose
(
sampled_token_ids
=
[[
0
]],
req_ids
=
[
"0"
],
num_tokens_no_spec
=
np
.
array
([
len
(
c
)
for
c
in
token_ids_cpu
]),
token_ids_cpu
=
token_ids_cpu
,
spec_decode_unsupported_reqs
=
(),
)
assert
np
.
array_equal
(
result
,
np
.
array
([[
100
,
1
]]))
...
...
@@ -152,10 +140,8 @@ def test_ngram_proposer():
token_ids_cpu
=
np
.
array
([[]])
result
=
get_ngram_proposer
(
min_n
=
2
,
max_n
=
2
,
k
=
2
).
propose
(
sampled_token_ids
=
[[
0
]],
req_ids
=
[
"0"
],
num_tokens_no_spec
=
np
.
array
([
len
(
c
)
for
c
in
token_ids_cpu
]),
token_ids_cpu
=
token_ids_cpu
,
spec_decode_unsupported_reqs
=
(),
)
assert
len
(
result
[
0
])
==
0
...
...
@@ -165,10 +151,8 @@ def test_ngram_proposer():
token_ids_cpu
=
np
.
array
([[
1
,
2
,
3
,
1
,
2
],
[
4
,
5
,
6
,
-
1
,
-
1
]])
result
=
get_ngram_proposer
(
min_n
=
2
,
max_n
=
2
,
k
=
2
).
propose
(
sampled_token_ids
=
[[
0
],
[
1
]],
req_ids
=
[
"0"
,
"1"
],
num_tokens_no_spec
=
np
.
array
([
5
,
3
]),
token_ids_cpu
=
token_ids_cpu
,
spec_decode_unsupported_reqs
=
(),
)
assert
len
(
result
[
0
])
==
2
assert
np
.
array_equal
(
result
[
0
],
np
.
array
([
3
,
1
]))
...
...
@@ -186,10 +170,8 @@ def test_ngram_proposer():
sampled_token_ids
=
[[
2
],
[],
[
8
]]
# Empty list for request 1 simulates prefill
result
=
proposer
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
req_ids
=
[
"0"
,
"1"
,
"2"
],
num_tokens_no_spec
=
num_tokens_no_spec
,
token_ids_cpu
=
token_ids_cpu
,
spec_decode_unsupported_reqs
=
(),
)
assert
len
(
result
)
==
3
assert
np
.
array_equal
(
result
[
0
],
[
3
,
1
])
...
...
@@ -217,10 +199,8 @@ def test_ngram_proposer():
token_ids_cpu
=
np
.
array
([
input_1
,
input_2
])
result
=
ngram_proposer
.
propose
(
sampled_token_ids
=
[[
0
],
[
1
]],
req_ids
=
[
"0"
,
"1"
],
num_tokens_no_spec
=
np
.
array
([
len
(
input_1
),
3
]),
token_ids_cpu
=
token_ids_cpu
,
spec_decode_unsupported_reqs
=
(),
)
assert
len
(
result
[
0
])
==
2
assert
np
.
array_equal
(
result
[
0
],
np
.
array
([
middle_integer
+
2
,
middle_integer
+
3
]))
...
...
tests/v1/spec_decode/test_tree_attention.py
View file @
7e63ef82
...
...
@@ -11,10 +11,10 @@ from tests.v1.attention.utils import (
create_vllm_config
,
try_get_attention_backend
,
)
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.utils.fa_utils
import
is_flash_attn_varlen_func_available
from
vllm.config
import
ParallelConfig
,
SpeculativeConfig
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backend
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.fa_utils
import
is_flash_attn_varlen_func_available
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
if
not
is_flash_attn_varlen_func_available
():
pytest
.
skip
(
...
...
tests/v1/spec_decode/untest_max_len.py
View file @
7e63ef82
...
...
@@ -38,53 +38,48 @@ def test_ngram_max_len(num_speculative_tokens: int):
def
test_eagle_max_len
(
monkeypatch
:
pytest
.
MonkeyPatch
,
num_speculative_tokens
:
int
,
attn_backend
:
str
):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
pytest
.
skip
(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if
attn_backend
==
"TRITON_ATTN"
and
not
current_platform
.
is_rocm
():
pytest
.
skip
(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
if
attn_backend
==
"ROCM_AITER_FA"
and
current_platform
.
is_rocm
():
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
if
attn_backend
==
"ROCM_AITER_FA"
and
current_platform
.
is_rocm
():
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B-Instruct"
,
enforce_eager
=
True
,
# For faster initialization.
speculative_config
=
{
"method"
:
"eagle"
,
"model"
:
"yuhuili/EAGLE-LLaMA3-Instruct-8B"
,
"num_speculative_tokens"
:
num_speculative_tokens
,
"max_model_len"
:
80
,
},
max_model_len
=
200
,
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B-Instruct"
,
enforce_eager
=
True
,
# For faster initialization.
speculative_config
=
{
"method"
:
"eagle"
,
"model"
:
"yuhuili/EAGLE-LLaMA3-Instruct-8B"
,
"num_speculative_tokens"
:
num_speculative_tokens
,
"max_model_len"
:
80
,
},
max_model_len
=
200
,
attention_config
=
{
"backend"
:
attn_backend
},
)
sampling_params
=
SamplingParams
(
max_tokens
=
200
,
ignore_eos
=
True
)
outputs
=
llm
.
generate
(
_PROMPTS
,
sampling_params
)
for
o
in
outputs
:
assert
o
.
outputs
[
0
].
finish_reason
==
"length"
,
(
"This test is only meaningful if the output is truncated due to max length"
)
sampling_params
=
SamplingParams
(
max_tokens
=
200
,
ignore_eos
=
True
)
outputs
=
llm
.
generate
(
_PROMPTS
,
sampling_params
)
for
o
in
outputs
:
assert
o
.
outputs
[
0
].
finish_reason
==
"length"
,
(
"This test is only meaningful if the output "
"is truncated due to max length"
)
sampling_params
=
SamplingParams
(
max_tokens
=
200
,
structured_outputs
=
StructuredOutputsParams
(
regex
=
"^"
+
"a b c d e "
*
15
+
"$"
),
sampling_params
=
SamplingParams
(
max_tokens
=
200
,
structured_outputs
=
StructuredOutputsParams
(
regex
=
"^"
+
"a b c d e "
*
15
+
"$"
),
)
output
=
llm
.
generate
(
_PROMPTS
,
sampling_params
)
for
o
in
output
:
assert
o
.
prompt_token_ids
is
not
None
assert
(
len
(
o
.
prompt_token_ids
)
<
80
<
len
(
o
.
prompt_token_ids
)
+
len
(
o
.
outputs
[
0
].
token_ids
)
<=
200
),
(
"This test is only meaningful if the output "
"is longer than the eagle max length"
)
output
=
llm
.
generate
(
_PROMPTS
,
sampling_params
)
for
o
in
output
:
assert
o
.
prompt_token_ids
is
not
None
assert
(
len
(
o
.
prompt_token_ids
)
<
80
<
len
(
o
.
prompt_token_ids
)
+
len
(
o
.
outputs
[
0
].
token_ids
)
<=
200
),
(
"This test is only meaningful if the output "
"is longer than the eagle max length"
)
assert
o
.
outputs
[
0
].
text
==
"a b c d e "
*
15
assert
o
.
outputs
[
0
].
text
==
"a b c d e "
*
15
tests/v1/structured_output/test_reasoning_structured_output.py
View file @
7e63ef82
...
...
@@ -71,6 +71,7 @@ class TestReasoningStructuredOutput:
request
.
prompt_token_ids
=
[
1
,
2
,
3
,
4
,
5
]
request
.
all_token_ids
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
]
request
.
num_computed_tokens
=
5
request
.
num_output_placeholders
=
0
return
request
def
test_should_fill_bitmask_with_enable_in_reasoning
(
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
7e63ef82
...
...
@@ -6,8 +6,6 @@ import numpy as np
import
pytest
import
torch
from
vllm.attention.backends.abstract
import
MultipleOf
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.layer
import
Attention
from
vllm.config
import
(
AttentionConfig
,
...
...
@@ -27,6 +25,9 @@ from vllm.platforms import current_platform
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.system_utils
import
update_environment_variables
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.backend
import
MultipleOf
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.core.kv_cache_utils
import
estimate_max_model_len
,
get_kv_cache_configs
from
vllm.v1.core.sched.output
import
CachedRequestData
,
NewRequestData
,
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
(
...
...
@@ -113,15 +114,16 @@ def get_vllm_config():
@
pytest
.
fixture
def
model_runner
():
vllm_config
=
get_vllm_config
()
model_config
=
vllm_config
.
model_config
num_heads
=
model_config
.
get_num_kv_heads
(
vllm_config
.
parallel_config
)
head_size
=
model_config
.
get_head_size
()
vllm_config
.
compilation_config
.
static_forward_context
[
"layer.0"
]
=
Attention
(
num_heads
,
head_size
,
0.1
)
runner
=
GPUModelRunner
(
vllm_config
,
DEVICE
)
initialize_kv_cache
(
runner
)
return
runner
with
set_current_vllm_config
(
vllm_config
):
model_config
=
vllm_config
.
model_config
num_heads
=
model_config
.
get_num_kv_heads
(
vllm_config
.
parallel_config
)
head_size
=
model_config
.
get_head_size
()
vllm_config
.
compilation_config
.
static_forward_context
[
"layer.0"
]
=
Attention
(
num_heads
,
head_size
,
0.1
)
runner
=
GPUModelRunner
(
vllm_config
,
DEVICE
)
initialize_kv_cache
(
runner
)
yield
runner
model_runner_2
=
model_runner
...
...
@@ -547,7 +549,7 @@ def test_reload_weights_before_load_model(model_runner):
model_runner
.
reload_weights
()
def
test_init_kv_cache_with_kv_sharing_invalid_target_layer_order
():
def
test_init_kv_cache_with_kv_sharing_invalid_target_layer_order
(
default_vllm_config
):
torch
.
set_default_dtype
(
torch
.
float16
)
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
...
...
@@ -574,7 +576,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
assert
fwd_context
is
not
None
def
test_init_kv_cache_with_kv_sharing_target_layer_not_exist
():
def
test_init_kv_cache_with_kv_sharing_target_layer_not_exist
(
default_vllm_config
):
torch
.
set_default_dtype
(
torch
.
float16
)
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
...
...
@@ -601,7 +603,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
assert
fwd_context
is
not
None
def
test_init_kv_cache_with_kv_sharing_target_same_as_current
():
def
test_init_kv_cache_with_kv_sharing_target_same_as_current
(
default_vllm_config
):
torch
.
set_default_dtype
(
torch
.
float16
)
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
...
...
@@ -628,7 +630,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
assert
fwd_context
is
not
None
def
test_init_kv_cache_without_kv_sharing
():
def
test_init_kv_cache_without_kv_sharing
(
default_vllm_config
):
torch
.
set_default_dtype
(
torch
.
float16
)
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
...
...
@@ -695,7 +697,7 @@ def test_init_kv_cache_without_kv_sharing():
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
1
]
==
layer_1
def
test_init_kv_cache_with_kv_sharing_valid
():
def
test_init_kv_cache_with_kv_sharing_valid
(
default_vllm_config
):
torch
.
set_default_dtype
(
torch
.
float16
)
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
...
...
@@ -778,7 +780,7 @@ def test_hybrid_attention_mamba_tensor_shapes():
will not corrupt an attention block and vice versa
"""
current_platform
.
seed_everything
(
42
)
set_random_seed
(
42
)
update_environment_variables
(
{
...
...
@@ -1048,7 +1050,7 @@ def test_input_batch_with_kernel_block_sizes():
assert
block_table
.
block_size
==
kernel_size
def
test_hybrid_cache_integration
(
model_runner
,
dist_init
):
def
test_hybrid_cache_integration
(
default_vllm_config
,
dist_init
):
"""Test hybrid cache architecture integration with GPUModelRunner."""
# Create a new model runner with hybrid cache configuration
vllm_config
=
get_vllm_config
()
...
...
@@ -1112,3 +1114,87 @@ def test_hybrid_cache_integration(model_runner, dist_init):
runner
.
_update_states
(
scheduler_output
)
assert
_is_req_scheduled
(
runner
,
req_id
)
assert
_is_req_state_block_table_match
(
runner
,
req_id
)
def
test_is_uniform_decode
()
->
None
:
# Normal
assert
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
1
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
16
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
2
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
16
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
1
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
15
,
)
# Spec decoding
assert
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
5
,
uniform_decode_query_len
=
5
,
num_tokens
=
30
,
num_reqs
=
6
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
5
,
uniform_decode_query_len
=
4
,
num_tokens
=
30
,
num_reqs
=
6
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
5
,
uniform_decode_query_len
=
5
,
num_tokens
=
30
,
num_reqs
=
7
,
)
# Force uniform decode
assert
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
1
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
16
,
force_uniform_decode
=
True
,
)
assert
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
2
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
16
,
force_uniform_decode
=
True
,
)
assert
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
1
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
15
,
force_uniform_decode
=
True
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
1
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
16
,
force_uniform_decode
=
False
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
2
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
16
,
force_uniform_decode
=
False
,
)
assert
not
GPUModelRunner
.
_is_uniform_decode
(
max_num_scheduled_tokens
=
1
,
uniform_decode_query_len
=
1
,
num_tokens
=
16
,
num_reqs
=
15
,
force_uniform_decode
=
False
,
)
Prev
1
…
29
30
31
32
33
34
35
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment