Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d75f22e
Commit
8d75f22e
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori
parents
ce888aa4
7d80c73d
Changes
679
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
930 additions
and
198 deletions
+930
-198
tests/v1/determinism/utils.py
tests/v1/determinism/utils.py
+9
-2
tests/v1/distributed/test_dbo.py
tests/v1/distributed/test_dbo.py
+19
-0
tests/v1/e2e/test_async_scheduling.py
tests/v1/e2e/test_async_scheduling.py
+2
-0
tests/v1/e2e/test_async_spec_decode.py
tests/v1/e2e/test_async_spec_decode.py
+131
-0
tests/v1/e2e/test_kv_sharing_fast_prefill.py
tests/v1/e2e/test_kv_sharing_fast_prefill.py
+19
-3
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+9
-5
tests/v1/ec_connector/integration/run_epd_correctness_test.sh
...s/v1/ec_connector/integration/run_epd_correctness_test.sh
+4
-4
tests/v1/ec_connector/unit/test_ec_example_connector.py
tests/v1/ec_connector/unit/test_ec_example_connector.py
+43
-43
tests/v1/engine/test_abort_final_step.py
tests/v1/engine/test_abort_final_step.py
+311
-0
tests/v1/engine/test_engine_args.py
tests/v1/engine/test_engine_args.py
+16
-0
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+2
-2
tests/v1/entrypoints/openai/test_completion_with_image_embeds.py
...1/entrypoints/openai/test_completion_with_image_embeds.py
+2
-15
tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
+21
-2
tests/v1/kv_connector/unit/test_backwards_compatibility.py
tests/v1/kv_connector/unit/test_backwards_compatibility.py
+4
-4
tests/v1/kv_connector/unit/test_example_connector.py
tests/v1/kv_connector/unit/test_example_connector.py
+12
-3
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
+5
-5
tests/v1/kv_connector/unit/test_lmcache_integration.py
tests/v1/kv_connector/unit/test_lmcache_integration.py
+2
-60
tests/v1/kv_connector/unit/test_multi_connector.py
tests/v1/kv_connector/unit/test_multi_connector.py
+11
-11
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+287
-32
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+21
-7
No files found.
Too many changes to show.
To preserve performance only
679 of 679+
files are displayed.
Plain diff
Email patch
tests/v1/determinism/utils.py
View file @
8d75f22e
...
...
@@ -11,12 +11,15 @@ from vllm.platforms import current_platform
from
vllm.utils.flashinfer
import
has_flashinfer
skip_unsupported
=
pytest
.
mark
.
skipif
(
not
(
current_platform
.
is_cuda
()
and
current_platform
.
has_device_capability
(
90
)),
reason
=
"Requires CUDA and >= Hopper (SM90)"
,
not
(
current_platform
.
is_cuda
()
and
current_platform
.
has_device_capability
(
80
)),
# Supports testing on Ampere and Ada Lovelace devices.
# Note: For devices with SM < 90, batch invariance does not support CUDA Graphs.
reason
=
"Requires CUDA and >= Ampere (SM80)"
,
)
BACKENDS
:
list
[
str
]
=
[
"FLASH_ATTN"
,
"TRITON_MLA"
,
]
if
has_flashinfer
():
...
...
@@ -96,3 +99,7 @@ def _extract_step_logprobs(request_output):
return
t
,
inner
.
token_ids
return
None
,
None
def is_device_capability_below_90() -> bool:
    """Tell whether the current platform fails the capability-90 check.

    Returns:
        ``True`` when ``current_platform.has_device_capability(90)`` is
        falsy, ``False`` otherwise.
    """
    meets_capability_90 = current_platform.has_device_capability(90)
    return not meets_capability_90
tests/v1/distributed/test_dbo.py
View file @
8d75f22e
...
...
@@ -9,10 +9,22 @@ correctly with the DeepSeek-V2-Lite model using GSM8K evaluation.
"""
import
pytest
import
torch
from
tests.evals.gsm8k.gsm8k_eval
import
evaluate_gsm8k
from
tests.utils
import
RemoteOpenAIServer
# Flag used to xfail on Blackwell-class GPUs: true only when CUDA is available
# and device 0 reports a major compute capability of at least 10.
IS_BLACKWELL = False
try:
    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability(0)
        IS_BLACKWELL = cap[0] >= 10
except Exception:
    # Stay conservative: when detection itself fails, do not xfail by default.
    IS_BLACKWELL = False
MODEL_NAME
=
"deepseek-ai/DeepSeek-V2-Lite-Chat"
DP_SIZE
=
2
...
...
@@ -33,6 +45,13 @@ DEEPEP_BACKENDS = [
@
pytest
.
mark
.
parametrize
(
"all2all_backend"
,
DEEPEP_BACKENDS
)
@
pytest
.
mark
.
xfail
(
IS_BLACKWELL
,
reason
=
(
"Temporary: DBO accuracy unstable on Blackwell "
"(doesn't meet expectation of MIN_ACCURACY = 0.62)"
),
)
def
test_dbo_dp_ep_gsm8k
(
all2all_backend
:
str
,
num_gpus_available
):
"""
Test DBO with DP+EP using GSM8K evaluation.
...
...
tests/v1/e2e/test_async_scheduling.py
View file @
8d75f22e
...
...
@@ -124,6 +124,8 @@ def run_tests(
with
monkeypatch
.
context
()
as
m
:
# avoid precision errors
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLEX_ATTENTION"
)
# lock matmul precision to full FP32
m
.
setenv
(
"VLLM_FLOAT32_MATMUL_PRECISION"
,
"highest"
)
# m.setenv("VLLM_BATCH_INVARIANT", "1")
outputs
:
list
[
tuple
[
str
,
list
,
list
]]
=
[]
for
n
,
(
...
...
tests/v1/e2e/test_async_spec_decode.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that verifies no implicit GPU-CPU synchronization occurs during
speculative decoding generation under expected conditions.
"""
import
multiprocessing
import
sys
import
traceback
import
pytest
import
torch
@pytest.fixture
def sync_tracker():
    """
    Count and report lazy initialisations of
    ``CommonAttentionMetadata.seq_lens_cpu`` for the duration of one test.

    The property is replaced by a wrapper that, whenever ``_seq_lens_cpu`` is
    still ``None`` at access time, bumps a shared counter and dumps the
    current stack to stderr before delegating to the original getter. The
    fixture yields an object exposing ``count`` and ``assert_no_sync()``;
    afterwards the original property is put back and ``torch._dynamo`` is
    reset.
    """
    from vllm.v1.attention.backends.utils import CommonAttentionMetadata

    # Keep the untouched property so it can be delegated to and restored.
    saved_property = CommonAttentionMetadata.seq_lens_cpu
    saved_getter = saved_property.fget

    # multiprocessing.Value lives in shared memory, so the counter can be
    # updated across processes (the original notes it is inherited by fork).
    lazy_init_hits = multiprocessing.Value("i", 0)
    banner = "=" * 60

    def tracking_seq_lens_cpu(self):
        # Fast path: value already materialised, nothing to report.
        if self._seq_lens_cpu is not None:
            return saved_getter(self)

        # Increment and read under the lock so the reported number is exact.
        with lazy_init_hits.get_lock():
            lazy_init_hits.value += 1
            hit_number = lazy_init_hits.value

        # Report right away so the trace appears in subprocess output.
        print(f"\n{banner}", file=sys.stderr)
        print(
            f"SYNC #{hit_number}: seq_lens_cpu lazy init triggered!",
            file=sys.stderr,
        )
        print(f"{banner}", file=sys.stderr)
        traceback.print_stack(file=sys.stderr)
        print(f"{banner}\n", file=sys.stderr)
        sys.stderr.flush()

        return saved_getter(self)

    # Install the tracking wrapper in place of the real property.
    CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)

    class SyncTracker:
        """Inspection handle over the shared lazy-init counter."""

        @property
        def count(self) -> int:
            """Number of lazy initialisations recorded so far."""
            return lazy_init_hits.value

        def assert_no_sync(self, msg: str = ""):
            """Fail if any lazy initialisation was recorded; *msg* is appended."""
            observed = lazy_init_hits.value
            assert observed == 0, (
                f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
                f"{observed} times. See stack traces above. {msg}"
            )

    yield SyncTracker()

    # Teardown: undo the patch and reset torch._dynamo state.
    CommonAttentionMetadata.seq_lens_cpu = saved_property
    torch._dynamo.reset()
# Test configurations: (model, spec_model, method, num_spec_tokens)
SPEC_DECODE_CONFIGS
=
[
pytest
.
param
(
"meta-llama/Llama-3.2-1B-Instruct"
,
"nm-testing/Llama3_2_1B_speculator.eagle3"
,
"eagle3"
,
2
,
id
=
"eagle3-llama"
,
),
pytest
.
param
(
"eagle618/deepseek-v3-random"
,
"eagle618/eagle-deepseek-v3-random"
,
"eagle"
,
2
,
id
=
"eagle-mla-deepseek"
,
),
]
@pytest.mark.parametrize(
    "model,spec_model,method,num_spec_tokens",
    SPEC_DECODE_CONFIGS,
)
def test_no_sync_with_spec_decode(
    sync_tracker,
    model: str,
    spec_model: str,
    method: str,
    num_spec_tokens: int,
):
    """
    Test that no implicit GPU-CPU sync occurs during speculative decoding
    generation.

    Args:
        sync_tracker: Fixture object whose ``assert_no_sync()`` fails if the
            patched ``seq_lens_cpu`` property recorded a lazy init.
        model: Target model name passed to ``LLM``.
        spec_model: Draft model name placed in ``speculative_config``.
        method: Speculative decoding method name.
        num_spec_tokens: Value for ``num_speculative_tokens``.
    """
    # Import vLLM AFTER sync_tracker fixture has applied the patch
    from vllm import LLM, SamplingParams
    from vllm.distributed import cleanup_dist_env_and_memory

    llm = LLM(
        model=model,
        max_model_len=256,
        speculative_config={
            "method": method,
            "num_speculative_tokens": num_spec_tokens,
            "model": spec_model,
        },
        enforce_eager=True,
        async_scheduling=True,
    )

    try:
        outputs = llm.generate(
            ["Hello, my name is"],
            SamplingParams(temperature=0, max_tokens=10),
        )

        assert len(outputs) == 1
        assert len(outputs[0].outputs[0].text) > 0
    finally:
        # Tear the engine down even when generation or an assertion fails, so
        # a failing case does not leave engine/distributed state behind for
        # the next parametrized case.
        del llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    # Checked last, after teardown, exactly as before.
    sync_tracker.assert_no_sync()
tests/v1/e2e/test_kv_sharing_fast_prefill.py
View file @
8d75f22e
...
...
@@ -7,6 +7,7 @@ import pytest
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationConfig
,
CompilationMode
from
vllm.platforms
import
current_platform
from
...utils
import
check_answers
,
fork_new_process_for_each_test
,
prep_prompts
...
...
@@ -43,15 +44,26 @@ def test_prompts():
return
prompts
@
fork_new_process_for_each_test
use_fork_for_test
=
(
fork_new_process_for_each_test
if
not
current_platform
.
is_rocm
()
else
lambda
x
:
x
)
@
use_fork_for_test
@
pytest
.
mark
.
parametrize
(
"kv_sharing_fast_prefill"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
,
False
])
def
test_kv_sharing_fast_prefill
(
monkeypatch
:
pytest
.
MonkeyPatch
,
kv_sharing_fast_prefill
:
bool
,
enforce_eager
:
bool
,
test_prompts
:
list
[
str
],
):
if
not
enforce_eager
and
current_platform
.
is_rocm
():
# Relevant context: https://github.com/vllm-project/vllm/pull/29244
pytest
.
skip
(
"ROCm: torch.compile produces incorrect output for gemma-3n's GELU "
"with tanh approximation. Use enforce_eager=True instead."
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
100
)
compilation_config
=
CompilationConfig
(
# This allows vLLM compilation backend to handle allocating and
...
...
@@ -65,6 +77,10 @@ def test_kv_sharing_fast_prefill(
with
monkeypatch
.
context
()
as
m
:
# Make scheduling deterministic for reproducibility
if
current_platform
.
is_rocm
():
# Use spawn to prevent cuda re-initialization error
m
.
setenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
)
else
:
m
.
setenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
)
prompts
,
answer
,
indices
=
prep_prompts
(
batch_size
)
...
...
tests/v1/e2e/test_spec_decode.py
View file @
8d75f22e
...
...
@@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance(
# Expect the acceptance rate to improve.
assert
first_accept_rate
<
last_accept_rate
# Heuristic: expect at least 8
2.5
% acceptance rate at the end.
assert
last_accept_rate
>
0.8
25
# Heuristic: expect at least 8
0.0
% acceptance rate at the end.
assert
last_accept_rate
>
0.8
0
del
spec_llm
torch
.
cuda
.
empty_cache
()
...
...
@@ -402,7 +402,11 @@ def test_eagle_correctness(
# Scout requires default backend selection
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
pass
# pass if not ROCm
if
current_platform
.
is_rocm
():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest
.
skip
(
"Flex Attn for spec_decode not supported on ROCm currently"
)
else
:
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
attn_backend
)
...
...
@@ -413,9 +417,9 @@ def test_eagle_correctness(
"multi-token eagle spec decode on current platform"
)
if
attn_backend
==
"
FLASH_ATTN
"
and
current_platform
.
is_rocm
():
if
attn_backend
==
"
ROCM_AITER_FA
"
and
current_platform
.
is_rocm
():
if
"deepseek"
in
model_setup
[
1
].
lower
():
pytest
.
skip
(
"
FLASH_ATTN
for deepseek not supported on ROCm platform"
)
pytest
.
skip
(
"
ROCM_AITER_FA
for deepseek not supported on ROCm platform"
)
else
:
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
...
...
tests/v1/ec_connector/integration/run_epd_correctness_test.sh
View file @
8d75f22e
...
...
@@ -148,7 +148,7 @@ run_epd_1e_1pd() {
--max-num-seqs
128
\
--allowed-local-media-path
${
GIT_ROOT
}
/tests/v1/ec_connector/integration
\
--ec-transfer-config
'{
"ec_connector": "EC
SharedStorag
eConnector",
"ec_connector": "EC
Exampl
eConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'
"
$EC_SHARED_STORAGE_PATH
"
'"
...
...
@@ -167,7 +167,7 @@ run_epd_1e_1pd() {
--max-num-seqs
128
\
--allowed-local-media-path
${
GIT_ROOT
}
/tests/v1/ec_connector/integration
\
--ec-transfer-config
'{
"ec_connector": "EC
SharedStorag
eConnector",
"ec_connector": "EC
Exampl
eConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'
"
$EC_SHARED_STORAGE_PATH
"
'"
...
...
@@ -348,7 +348,7 @@ run_epd_1e_1p_1d() {
--max-num-seqs
128
\
--allowed-local-media-path
${
GIT_ROOT
}
/tests/v1/ec_connector/integration
\
--ec-transfer-config
'{
"ec_connector": "EC
SharedStorag
eConnector",
"ec_connector": "EC
Exampl
eConnector",
"ec_role": "ec_producer",
"ec_connector_extra_config": {
"shared_storage_path": "'
"
$EC_SHARED_STORAGE_PATH
"
'"
...
...
@@ -369,7 +369,7 @@ run_epd_1e_1p_1d() {
--max-num-seqs
128
\
--allowed-local-media-path
${
GIT_ROOT
}
/tests/v1/ec_connector/integration
\
--ec-transfer-config
'{
"ec_connector": "EC
SharedStorag
eConnector",
"ec_connector": "EC
Exampl
eConnector",
"ec_role": "ec_consumer",
"ec_connector_extra_config": {
"shared_storage_path": "'
"
$EC_SHARED_STORAGE_PATH
"
'"
...
...
tests/v1/ec_connector/unit/test_ec_
shared_storag
e_connector.py
→
tests/v1/ec_connector/unit/test_ec_
exampl
e_connector.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for EC
SharedStorag
eConnector.
Unit tests for EC
Exampl
eConnector.
"""
import
os
...
...
@@ -13,9 +13,9 @@ import torch
from
vllm.config
import
VllmConfig
from
vllm.distributed.ec_transfer.ec_connector.base
import
ECConnectorRole
from
vllm.distributed.ec_transfer.ec_connector.
shared_storag
e_connector
import
(
EC
SharedStorag
eConnector
,
EC
SharedStorag
eConnectorMetadata
,
from
vllm.distributed.ec_transfer.ec_connector.
exampl
e_connector
import
(
EC
Exampl
eConnector
,
EC
Exampl
eConnectorMetadata
,
MMMeta
,
)
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
PlaceholderRange
...
...
@@ -81,12 +81,12 @@ def mock_request_with_3_mm():
# ------------------ Unit Tests ------------------ #
class
TestEC
SharedStorag
eConnectorBasics
:
class
TestEC
Exampl
eConnectorBasics
:
"""Test basic EC connector functionality."""
def
test_initialization_producer
(
self
,
mock_vllm_config_producer
,
temp_storage
):
"""Test connector initializes correctly as producer."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
SCHEDULER
,
)
...
...
@@ -98,7 +98,7 @@ class TestECSharedStorageConnectorBasics:
def
test_initialization_consumer
(
self
,
mock_vllm_config_consumer
,
temp_storage
):
"""Test connector initializes correctly as consumer."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_consumer
,
role
=
ECConnectorRole
.
WORKER
,
)
...
...
@@ -109,11 +109,11 @@ class TestECSharedStorageConnectorBasics:
def
test_role_assignment
(
self
,
mock_vllm_config_producer
):
"""Test role is correctly assigned."""
scheduler_connector
=
EC
SharedStorag
eConnector
(
scheduler_connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
SCHEDULER
,
)
worker_connector
=
EC
SharedStorag
eConnector
(
worker_connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
WORKER
,
)
...
...
@@ -133,7 +133,7 @@ class TestCacheExistence:
):
"""Test has_caches returns True when all 3 caches exist."""
# Test for producer first
producer
=
EC
SharedStorag
eConnector
(
producer
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
SCHEDULER
,
)
...
...
@@ -154,7 +154,7 @@ class TestCacheExistence:
assert
all
(
producer_result
),
f
"Expected all True, got
{
producer_result
}
"
# Also test consumer can check if cache exists
consumer
=
EC
SharedStorag
eConnector
(
consumer
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_consumer
,
role
=
ECConnectorRole
.
SCHEDULER
,
)
...
...
@@ -170,7 +170,7 @@ class TestCacheExistence:
self
,
mock_vllm_config_producer
,
mock_request_with_3_mm
):
"""Test has_caches returns False when no caches exist."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
SCHEDULER
,
)
...
...
@@ -186,7 +186,7 @@ class TestCacheExistence:
self
,
mock_vllm_config_producer
,
mock_request_with_3_mm
):
"""Test has_caches with some caches existing (1 of 3)."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
SCHEDULER
,
)
...
...
@@ -213,7 +213,7 @@ class TestStateManagement:
self
,
mock_vllm_config_producer
,
mock_request_with_3_mm
):
"""Test state update after allocation for 3 MM items."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
SCHEDULER
,
)
...
...
@@ -238,7 +238,7 @@ class TestStateManagement:
self
,
mock_vllm_config_producer
,
mock_request_with_3_mm
):
"""Test metadata building for 3 MM items."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
SCHEDULER
,
)
...
...
@@ -252,7 +252,7 @@ class TestStateManagement:
metadata
=
connector
.
build_connector_meta
(
scheduler_output
)
# Assert
assert
isinstance
(
metadata
,
EC
SharedStorag
eConnectorMetadata
)
assert
isinstance
(
metadata
,
EC
Exampl
eConnectorMetadata
)
assert
len
(
metadata
.
mm_datas
)
==
3
assert
metadata
.
mm_datas
[
0
].
mm_hash
==
"img_hash_1"
assert
metadata
.
mm_datas
[
0
].
num_token
==
100
...
...
@@ -266,7 +266,7 @@ class TestStateManagement:
def
test_build_connector_meta_empty
(
self
,
mock_vllm_config_producer
):
"""Test metadata building with empty state."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
SCHEDULER
,
)
...
...
@@ -274,14 +274,14 @@ class TestStateManagement:
scheduler_output
=
Mock
(
spec
=
SchedulerOutput
)
metadata
=
connector
.
build_connector_meta
(
scheduler_output
)
assert
isinstance
(
metadata
,
EC
SharedStorag
eConnectorMetadata
)
assert
isinstance
(
metadata
,
EC
Exampl
eConnectorMetadata
)
assert
len
(
metadata
.
mm_datas
)
==
0
def
test_state_cleared_after_metadata_build
(
self
,
mock_vllm_config_producer
,
mock_request_with_3_mm
):
"""Test that state is properly cleared after building metadata."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
SCHEDULER
,
)
...
...
@@ -310,7 +310,7 @@ class TestCacheSaving:
self
,
mock_vllm_config_producer
,
mock_request_with_3_mm
,
temp_storage
):
"""Test cache saving as producer for 3 different MM items."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
WORKER
,
)
...
...
@@ -336,7 +336,7 @@ class TestCacheSaving:
def
test_save_caches_consumer_skips
(
self
,
mock_vllm_config_consumer
):
"""Test cache saving is skipped for consumer."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_consumer
,
role
=
ECConnectorRole
.
WORKER
,
)
...
...
@@ -366,7 +366,7 @@ class TestCacheLoading:
):
"""Test consumer loads 3 caches from storage."""
# First, create producer to save caches
producer
=
EC
SharedStorag
eConnector
(
producer
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
WORKER
,
)
...
...
@@ -379,13 +379,13 @@ class TestCacheLoading:
producer
.
save_caches
(
saved_caches
,
mm_hash
)
# Now consumer loads
consumer
=
EC
SharedStorag
eConnector
(
consumer
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_consumer
,
role
=
ECConnectorRole
.
WORKER
,
)
# Setup metadata for all 3
metadata
=
EC
SharedStorag
eConnectorMetadata
()
metadata
=
EC
Exampl
eConnectorMetadata
()
for
mm_hash
in
mm_hashes
:
metadata
.
add_mm_data
(
MMMeta
.
make_meta
(
mm_hash
,
100
))
consumer
.
bind_connector_metadata
(
metadata
)
...
...
@@ -410,7 +410,7 @@ class TestCacheLoading:
):
"""Test cache loading skips already cached items."""
# Setup: producer saves cache
producer
=
EC
SharedStorag
eConnector
(
producer
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
WORKER
,
)
...
...
@@ -420,12 +420,12 @@ class TestCacheLoading:
producer
.
save_caches
({
mm_hash
:
saved_cache
},
mm_hash
)
# Consumer setup
consumer
=
EC
SharedStorag
eConnector
(
consumer
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_consumer
,
role
=
ECConnectorRole
.
WORKER
,
)
metadata
=
EC
SharedStorag
eConnectorMetadata
()
metadata
=
EC
Exampl
eConnectorMetadata
()
metadata
.
add_mm_data
(
MMMeta
.
make_meta
(
mm_hash
,
100
))
consumer
.
bind_connector_metadata
(
metadata
)
...
...
@@ -444,13 +444,13 @@ class TestCacheLoading:
def
test_start_load_caches_empty_metadata
(
self
,
mock_vllm_config_consumer
):
"""Test loading with empty metadata does nothing."""
consumer
=
EC
SharedStorag
eConnector
(
consumer
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_consumer
,
role
=
ECConnectorRole
.
WORKER
,
)
# Setup empty metadata
metadata
=
EC
SharedStorag
eConnectorMetadata
()
metadata
=
EC
Exampl
eConnectorMetadata
()
consumer
.
bind_connector_metadata
(
metadata
)
# Load (should not raise)
...
...
@@ -466,7 +466,7 @@ class TestFilenameGeneration:
def
test_generate_foldername
(
self
,
mock_vllm_config_producer
,
temp_storage
):
"""Test folder name generation."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
WORKER
,
)
...
...
@@ -479,7 +479,7 @@ class TestFilenameGeneration:
def
test_generate_filename
(
self
,
mock_vllm_config_producer
,
temp_storage
):
"""Test filename generation."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
WORKER
,
)
...
...
@@ -493,7 +493,7 @@ class TestFilenameGeneration:
def
test_generate_filename_consistency
(
self
,
mock_vllm_config_producer
):
"""Test filename generation is consistent."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
WORKER
,
)
...
...
@@ -510,12 +510,12 @@ class TestMetadataBindingLifecycle:
def
test_bind_connector_metadata
(
self
,
mock_vllm_config_consumer
):
"""Test binding connector metadata."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_consumer
,
role
=
ECConnectorRole
.
WORKER
,
)
metadata
=
EC
SharedStorag
eConnectorMetadata
()
metadata
=
EC
Exampl
eConnectorMetadata
()
metadata
.
add_mm_data
(
MMMeta
.
make_meta
(
"hash_1"
,
100
))
connector
.
bind_connector_metadata
(
metadata
)
...
...
@@ -524,12 +524,12 @@ class TestMetadataBindingLifecycle:
def
test_clear_connector_metadata
(
self
,
mock_vllm_config_consumer
):
"""Test clearing connector metadata."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_consumer
,
role
=
ECConnectorRole
.
WORKER
,
)
metadata
=
EC
SharedStorag
eConnectorMetadata
()
metadata
=
EC
Exampl
eConnectorMetadata
()
connector
.
bind_connector_metadata
(
metadata
)
connector
.
clear_connector_metadata
()
...
...
@@ -538,12 +538,12 @@ class TestMetadataBindingLifecycle:
def
test_get_connector_metadata
(
self
,
mock_vllm_config_consumer
):
"""Test getting connector metadata."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_consumer
,
role
=
ECConnectorRole
.
WORKER
,
)
metadata
=
EC
SharedStorag
eConnectorMetadata
()
metadata
=
EC
Exampl
eConnectorMetadata
()
connector
.
bind_connector_metadata
(
metadata
)
retrieved
=
connector
.
_get_connector_metadata
()
...
...
@@ -552,7 +552,7 @@ class TestMetadataBindingLifecycle:
def
test_get_connector_metadata_not_set
(
self
,
mock_vllm_config_consumer
):
"""Test getting metadata when not set raises."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_consumer
,
role
=
ECConnectorRole
.
WORKER
,
)
...
...
@@ -566,7 +566,7 @@ class TestEdgeCases:
def
test_save_empty_cache
(
self
,
mock_vllm_config_producer
):
"""Test saving empty tensor."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
WORKER
,
)
...
...
@@ -579,12 +579,12 @@ class TestEdgeCases:
def
test_load_nonexistent_cache
(
self
,
mock_vllm_config_consumer
):
"""Test loading cache that doesn't exist raises error."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_consumer
,
role
=
ECConnectorRole
.
WORKER
,
)
metadata
=
EC
SharedStorag
eConnectorMetadata
()
metadata
=
EC
Exampl
eConnectorMetadata
()
metadata
.
add_mm_data
(
MMMeta
.
make_meta
(
"nonexistent_hash"
,
100
))
connector
.
bind_connector_metadata
(
metadata
)
...
...
@@ -596,7 +596,7 @@ class TestEdgeCases:
def
test_has_caches_empty_request
(
self
,
mock_vllm_config_producer
):
"""Test has_caches with request that has no MM data."""
connector
=
EC
SharedStorag
eConnector
(
connector
=
EC
Exampl
eConnector
(
vllm_config
=
mock_vllm_config_producer
,
role
=
ECConnectorRole
.
SCHEDULER
,
)
...
...
tests/v1/engine/test_abort_final_step.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test for the fix in PR #29987: Eagerly abort cancelled final-step requests.
This test verifies that when a request is aborted during its final execution
step (when it would naturally complete), it is properly marked as aborted
rather than being treated as normally completed.
The test uses a dummy KV connector to verify that the connector receives
the correct finish status (FINISHED_ABORTED, not FINISHED_LENGTH_CAPPED).
"""
import
asyncio
import
tempfile
import
time
from
pathlib
import
Path
from
typing
import
Any
from
unittest.mock
import
patch
import
pytest
from
vllm
import
SamplingParams
from
vllm.config
import
KVTransferConfig
,
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.factory
import
KVConnectorFactory
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
,
)
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.utils.torch_utils
import
set_default_torch_num_threads
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
if
not
current_platform
.
is_cuda
():
pytest
.
skip
(
reason
=
"V1 currently only supported on CUDA."
,
allow_module_level
=
True
)
TEXT_PROMPT
=
"Hello"
class DummyKVConnectorMetadata(KVConnectorMetadata):
    """Minimal connector metadata for the test connector.

    Holds nothing but a ``requests`` list, which starts out empty.
    """

    def __init__(self):
        # Fresh list per instance; the base initializer is not invoked here.
        self.requests: list = list()
class DummyKVConnector(KVConnectorBase_V1):
    """
    Dummy KV connector that captures request finish statuses to a file.

    This is used to verify the fix - without the fix, a request aborted
    during its final step would be captured as FINISHED_LENGTH_CAPPED
    instead of FINISHED_ABORTED.

    The connector runs in a separate process, so we write statuses to a file
    that can be read by the test process.

    All file writes are best-effort: a failure is reported on stdout and
    never raised, because this class is test instrumentation only.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: KVCacheConfig | None = None,
    ):
        """Read ``status_file`` from the extra config and record an INIT line.

        Args:
            vllm_config: Engine configuration; its
                ``kv_transfer_config.kv_connector_extra_config`` may carry a
                ``"status_file"`` path.
            role: Connector role; its ``name`` is written to the status file.
            kv_cache_config: Forwarded unchanged to the base class.
        """
        super().__init__(vllm_config, role, kv_cache_config)
        # Get the status file path from extra config
        extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config or {}
        self.status_file = extra_config.get("status_file")
        # Log that we were initialized
        if self.status_file:
            try:
                with open(self.status_file, "a") as f:
                    f.write(f"INIT:{role.name}\n")
            except Exception as e:
                # Report instead of swallowing silently, in the same way as
                # request_finished; still never fail connector construction.
                print(f"[DummyKVConnector] Failed to write init marker: {e}")

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """Always report zero matched tokens with a ``False`` flag."""
        return (0, False)

    def update_state_after_alloc(
        self,
        request: Request,
        blocks: Any,
        num_external_tokens: int,
    ):
        """No-op."""
        pass

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """Return a fresh, empty ``DummyKVConnectorMetadata``."""
        return DummyKVConnectorMetadata()

    def request_finished(
        self,
        request: Request,
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Capture the request status when finished by writing to a file.

        Appends ``request.status.name`` as one line to ``status_file`` when a
        path was configured. Always returns ``(False, None)``.
        """
        if self.status_file:
            try:
                with open(self.status_file, "a") as f:
                    # Write the status name (e.g., "FINISHED_ABORTED")
                    f.write(f"{request.status.name}\n")
            except Exception as e:
                # Log but don't fail - this is just test instrumentation
                print(f"[DummyKVConnector] Failed to write status: {e}")
        return False, None

    def start_load_kv(self, forward_context: Any, **kwargs: Any) -> None:
        """No-op."""
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        """No-op."""
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: Any,
        attn_metadata: Any,
        **kwargs: Any,
    ) -> None:
        """No-op."""
        pass

    def wait_for_save(self):
        """No-op."""
        pass
# Register the dummy connector
KVConnectorFactory
.
register_connector
(
"DummyKVConnector"
,
__name__
,
DummyKVConnector
.
__name__
)
@pytest.mark.parametrize("async_scheduling", [False, True])
@pytest.mark.asyncio
async def test_abort_during_final_step(async_scheduling: bool):
    """
    Test that a request aborted during its final execution step is treated as
    aborted rather than completed.

    This test:
    1. Monkeypatches execute_model to wait for a file to be deleted
    2. Configures a dummy KV connector to capture finish statuses
    3. Starts a request with max_tokens=1 (will complete on first decode step)
    4. Aborts the request, then deletes the file to unblock execute_model
    5. Verifies the KV connector received FINISHED_ABORTED not FINISHED_LENGTH_CAPPED

    See https://github.com/vllm-project/vllm/pull/29987.

    Without the fix, the KV connector would see FINISHED_LENGTH_CAPPED because
    update_from_output() would mark the request as completed before processing
    the abort. This causes KV cache blocks to not be freed properly in
    disaggregated prefill scenarios.

    With the fix, _process_aborts_queue() runs before update_from_output(), so the
    abort takes precedence and the KV connector sees FINISHED_ABORTED.
    """
    # Create three temporary files:
    # 1. ready_file: deleted by execute_model to signal it has started
    # 2. block_file: execute_model waits for this to be deleted
    # 3. status_file: KV connector writes finish statuses here
    with tempfile.NamedTemporaryFile(delete=False) as f:
        ready_file = Path(f.name)
    with tempfile.NamedTemporaryFile(delete=False) as f2:
        block_file = Path(f2.name)
    with tempfile.NamedTemporaryFile(delete=False, mode="w") as f3:
        status_file = Path(f3.name)

    try:
        # Get the original execute_model method
        from vllm.v1.worker.gpu_worker import Worker

        original_execute_model = Worker.execute_model

        def execute_model_with_wait(self, scheduler_output):
            # Signal that execute_model has been called by deleting ready_file.
            # missing_ok avoids the exists()/unlink() race on repeated calls.
            ready_file.unlink(missing_ok=True)

            # Wait for the block file to be deleted (triggered from test after abort)
            # This runs in the worker process (after fork), so we poll the filesystem
            while block_file.exists():
                time.sleep(0.01)
            return original_execute_model(self, scheduler_output)

        # Patch execute_model to inject the wait
        # This happens before the worker process is forked, so the patch applies there
        with patch.object(Worker, "execute_model", execute_model_with_wait):
            request_id = "test-abort-final-step"

            # Configure engine with dummy KV connector
            # Pass the status file path so the connector can write to it
            kv_transfer_config = KVTransferConfig(
                kv_connector="DummyKVConnector",
                kv_role="kv_both",
                kv_connector_extra_config={"status_file": str(status_file)},
            )
            engine_args = AsyncEngineArgs(
                model="meta-llama/Llama-3.2-1B-Instruct",
                enforce_eager=True,
                async_scheduling=async_scheduling,
                kv_transfer_config=kv_transfer_config,
            )

            with set_default_torch_num_threads(1):
                engine = AsyncLLM.from_engine_args(engine_args)

            try:
                # Create a request that will complete after just 1 token
                sampling_params = SamplingParams(
                    max_tokens=1,
                    ignore_eos=True,
                    output_kind=RequestOutputKind.DELTA,
                )

                # Start generation in a task
                outputs = []

                async def generate():
                    async for output in engine.generate(
                        request_id=request_id,
                        prompt=TEXT_PROMPT,
                        sampling_params=sampling_params,
                    ):
                        outputs.append(output)

                gen_task = asyncio.create_task(generate())

                # Wait for execute_model to signal it has started (with timeout).
                # A monotonic clock keeps the timeout immune to wall-clock jumps.
                timeout = 5.0  # 5 second timeout
                start_time = time.monotonic()
                while ready_file.exists():
                    if time.monotonic() - start_time > timeout:
                        raise TimeoutError(
                            "Timeout waiting for execute_model to start. "
                            "The monkeypatch may not be working correctly, "
                            "for example if spawn was used instead of fork."
                        )
                    await asyncio.sleep(0.01)

                # Abort the request while execute_model is blocked
                await engine.abort(request_id)

                # Now unblock execute_model by deleting the file
                # The abort should be processed before the model output
                block_file.unlink()

                # Wait for generation to complete
                await gen_task

                # Give the scheduler a moment to finish cleanup
                await asyncio.sleep(0.1)

                # Verify we got output
                assert len(outputs) > 0, "Should have received at least one output"

                # The final output should have finish_reason="abort"
                final_output = outputs[-1]
                assert final_output.finished, (
                    "Final output should be marked as finished"
                )
                assert final_output.outputs[0].finish_reason == "abort", (
                    f"Expected finish_reason='abort' but got "
                    f"'{final_output.outputs[0].finish_reason}'. "
                )

                # The connector appends its status independently of this task,
                # so the line may not be on disk yet. Poll for it (bounded)
                # instead of relying solely on the fixed sleep above.
                status_deadline = time.monotonic() + timeout
                while True:
                    status_lines = status_file.read_text().strip().split("\n")
                    # Filter for actual finish statuses (not INIT or empty lines)
                    captured_statuses = [
                        line
                        for line in status_lines
                        if line and line.startswith("FINISHED_")
                    ]
                    if captured_statuses or time.monotonic() > status_deadline:
                        break
                    await asyncio.sleep(0.01)

                assert len(captured_statuses) >= 1, (
                    f"Expected at least 1 captured finish status, got "
                    f"{len(captured_statuses)}. File content: {status_lines}"
                )
                assert "FINISHED_ABORTED" in captured_statuses, (
                    f"KV connector should see FINISHED_ABORTED but got "
                    f"{captured_statuses}. "
                )

                # Verify cleanup
                assert not engine.output_processor.has_unfinished_requests()
            finally:
                # If the test failed before unblocking execute_model, release
                # the worker so it is not left polling during shutdown.
                block_file.unlink(missing_ok=True)
                # Shutdown the engine
                engine.shutdown()
    finally:
        # Clean up temporary files if they still exist
        for leftover in (ready_file, block_file, status_file):
            leftover.unlink(missing_ok=True)
tests/v1/engine/test_engine_args.py
View file @
8d75f22e
...
...
@@ -9,6 +9,7 @@ from vllm.config import VllmConfig
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.hashing
import
_xxhash
def
test_prefix_caching_from_cli
():
...
...
@@ -48,6 +49,21 @@ def test_prefix_caching_from_cli():
args
=
parser
.
parse_args
([
"--prefix-caching-hash-algo"
,
"invalid"
])
@pytest.mark.skipif(_xxhash is None, reason="xxhash not installed")
def test_prefix_caching_xxhash_from_cli():
    """Check that the xxhash-based prefix-caching hash algorithms parse from the CLI.

    Each value passed to ``--prefix-caching-hash-algo`` must come back
    unchanged as ``cache_config.prefix_caching_hash_algo``.
    """
    cli_parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    # First "xxhash" (pickle), then "xxhash_cbor"
    for hash_algo in ("xxhash", "xxhash_cbor"):
        parsed = cli_parser.parse_args(["--prefix-caching-hash-algo", hash_algo])
        engine_config = EngineArgs.from_cli_args(args=parsed).create_engine_config()
        assert engine_config.cache_config.prefix_caching_hash_algo == hash_algo
def
test_defaults_with_usage_context
():
engine_args
=
EngineArgs
(
model
=
"facebook/opt-125m"
)
vllm_config
:
VllmConfig
=
engine_args
.
create_engine_config
(
UsageContext
.
LLM_CLASS
)
...
...
tests/v1/engine/test_engine_core.py
View file @
8d75f22e
...
...
@@ -507,7 +507,7 @@ def test_encoder_instance_zero_kv_cache(
)
kv_transfer_config
=
(
KVTransferConfig
(
kv_connector
=
"
SharedStorag
eConnector"
,
kv_connector
=
"
Exampl
eConnector"
,
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
"shared_storage_path"
:
"local_storage"
},
)
...
...
@@ -515,7 +515,7 @@ def test_encoder_instance_zero_kv_cache(
else
None
)
ec_transfer_config
=
ECTransferConfig
(
ec_connector
=
"EC
SharedStorag
eConnector"
,
ec_connector
=
"EC
Exampl
eConnector"
,
ec_role
=
ec_role
,
ec_connector_extra_config
=
{
"shared_storage_path"
:
"/tmp/ec_test_encoder"
},
)
...
...
tests/v1/entrypoints/openai/test_completion_with_image_embeds.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
base64
import
io
import
json
import
openai
# use the official client for correctness check
...
...
@@ -13,6 +11,7 @@ from transformers import AutoConfig
from
tests.conftest
import
ImageTestAssets
from
tests.utils
import
RemoteOpenAIServer
from
vllm.utils.serial_utils
import
tensor2base64
# any model with a chat template should work here
MODEL_NAME
=
"llava-hf/llava-1.5-7b-hf"
...
...
@@ -50,18 +49,6 @@ async def client_with_image_embeds(server_with_image_embeds):
yield
async_client
def encode_image_embedding_to_base64(image_embedding) -> str:
    """
    Serialize an image embedding with ``torch.save`` and return the
    resulting bytes as a base64-encoded UTF-8 string.
    """
    stream = io.BytesIO()
    torch.save(image_embedding, stream)
    # getvalue() yields the whole buffer, same as rewinding and reading it
    serialized = stream.getvalue()
    return base64.b64encode(serialized).decode("utf-8")
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
half
,
torch
.
float16
,
torch
.
float32
])
...
...
@@ -73,7 +60,7 @@ async def test_completions_with_image_embeds(
):
# Test case: Single image embeds input
image_embeds
=
image_assets
[
0
].
image_embeds
.
to
(
dtype
=
dtype
)
base64_image_embedding
=
en
code_image_embedding_to_
base64
(
image_embeds
)
base64_image_embedding
=
t
en
sor2
base64
(
image_embeds
)
chat_completion
=
await
client_with_image_embeds
.
chat
.
completions
.
create
(
messages
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant."
},
...
...
tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
View file @
8d75f22e
...
...
@@ -30,7 +30,14 @@ async def lifespan(app: FastAPI):
prefiller_base_url
=
f
"http://
{
host
}
:
{
port
}
/v1"
app
.
state
.
prefill_clients
.
append
(
{
"client"
:
httpx
.
AsyncClient
(
timeout
=
None
,
base_url
=
prefiller_base_url
),
"client"
:
httpx
.
AsyncClient
(
timeout
=
None
,
base_url
=
prefiller_base_url
,
limits
=
httpx
.
Limits
(
max_connections
=
None
,
max_keepalive_connections
=
None
,
),
),
"host"
:
host
,
"port"
:
port
,
"id"
:
i
,
...
...
@@ -42,7 +49,14 @@ async def lifespan(app: FastAPI):
decoder_base_url
=
f
"http://
{
host
}
:
{
port
}
/v1"
app
.
state
.
decode_clients
.
append
(
{
"client"
:
httpx
.
AsyncClient
(
timeout
=
None
,
base_url
=
decoder_base_url
),
"client"
:
httpx
.
AsyncClient
(
timeout
=
None
,
base_url
=
decoder_base_url
,
limits
=
httpx
.
Limits
(
max_connections
=
None
,
max_keepalive_connections
=
None
,
),
),
"host"
:
host
,
"port"
:
port
,
"id"
:
i
,
...
...
@@ -169,6 +183,10 @@ async def send_request_to_service(
)
response
.
raise_for_status
()
# read/consume the response body to release the connection
# otherwise, it would raise httpx.ReadError
await
response
.
aread
()
return
response
...
...
@@ -206,6 +224,7 @@ async def _handle_completions(api: str, request: Request):
# Extract the needed fields
response_json
=
response
.
json
()
await
response
.
aclose
()
# CRITICAL: Release connection back to pool
kv_transfer_params
=
response_json
.
get
(
"kv_transfer_params"
,
{})
if
kv_transfer_params
:
req_data
[
"kv_transfer_params"
]
=
kv_transfer_params
...
...
tests/v1/kv_connector/unit/test_backwards_compatibility.py
View file @
8d75f22e
...
...
@@ -218,12 +218,12 @@ def test_internal_connector_uses_new_signature():
Test that internal connectors (registered in factory) always use the new
signature and get kv_cache_config.
"""
from
vllm.distributed.kv_transfer.kv_connector.v1.
shared_storag
e_connector
import
(
SharedStorag
eConnector
,
from
vllm.distributed.kv_transfer.kv_connector.v1.
exampl
e_connector
import
(
Exampl
eConnector
,
)
vllm_config
=
create_vllm_config
()
vllm_config
.
kv_transfer_config
.
kv_connector
=
"
SharedStorag
eConnector"
vllm_config
.
kv_transfer_config
.
kv_connector
=
"
Exampl
eConnector"
scheduler
=
create_scheduler
(
vllm_config
)
kv_cache_config
=
scheduler
.
kv_cache_config
...
...
@@ -233,7 +233,7 @@ def test_internal_connector_uses_new_signature():
)
assert
connector
is
not
None
assert
isinstance
(
connector
,
SharedStorag
eConnector
)
assert
isinstance
(
connector
,
Exampl
eConnector
)
assert
connector
.
_kv_cache_config
is
not
None
assert
connector
.
_kv_cache_config
==
kv_cache_config
...
...
tests/v1/kv_connector/unit/test_
shared_storag
e_connector.py
→
tests/v1/kv_connector/unit/test_
exampl
e_connector.py
View file @
8d75f22e
...
...
@@ -3,12 +3,14 @@
from
dataclasses
import
asdict
from
typing
import
NamedTuple
import
pytest
from
PIL
import
Image
from
vllm
import
LLM
,
EngineArgs
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.config
import
KVTransferConfig
from
vllm.multimodal.utils
import
encode_image_base64
from
vllm.platforms
import
current_platform
MODEL_NAME
=
"RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"
...
...
@@ -108,18 +110,25 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
print
(
"-"
*
50
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
(
"hipErrorLaunchFailure when running this test, see issue:"
"https://github.com/ROCm/pytorch/issues/2822"
),
)
def
test_shared_storage_connector_hashes
(
tmp_path
):
"""
Tests that
SharedStorag
eConnector saves KV to the storage locations
Tests that
Exampl
eConnector saves KV to the storage locations
with proper hashes; that are unique for inputs with identical text but
different images (same size), or same multiple images but different orders.
"""
# Using tmp_path as the storage path to store KV
print
(
f
"KV storage path at:
{
str
(
tmp_path
)
}
"
)
# Configure the
SharedStorag
eConnector
# Configure the
Exampl
eConnector
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"
SharedStorag
eConnector"
,
kv_connector
=
"
Exampl
eConnector"
,
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
"shared_storage_path"
:
str
(
tmp_path
)},
)
...
...
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.distributed.kv_transfer.kv_connector.v1.
shared_storag
e_connector
import
(
# noqa: E501
SharedStorag
eConnectorMetadata
,
from
vllm.distributed.kv_transfer.kv_connector.v1.
exampl
e_connector
import
(
# noqa: E501
Exampl
eConnectorMetadata
,
)
from
vllm.distributed.kv_transfer.kv_transfer_state
import
(
ensure_kv_transfer_initialized
,
...
...
@@ -11,7 +11,7 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
from
vllm.v1.core.sched.output
import
CachedRequestData
,
SchedulerOutput
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
# Importing utils registers Test
SharedStorag
eConnector with the factory
# Importing utils registers Test
Exampl
eConnector with the factory
from
.utils
import
create_vllm_config
...
...
@@ -26,13 +26,13 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks
=
[],
finished_req_ids
=
set
(),
free_encoder_mm_hashes
=
[],
kv_connector_metadata
=
SharedStorag
eConnectorMetadata
(),
kv_connector_metadata
=
Exampl
eConnectorMetadata
(),
)
def
test_kv_connector_mixin_clears_metadata
():
vllm_config
=
create_vllm_config
()
vllm_config
.
kv_transfer_config
.
kv_connector
=
"Test
SharedStorag
eConnector"
vllm_config
.
kv_transfer_config
.
kv_connector
=
"Test
Exampl
eConnector"
vllm_config
.
kv_transfer_config
.
kv_role
=
"kv_both"
vllm_config
.
kv_transfer_config
.
kv_connector_extra_config
[
"name"
]
=
"unit"
...
...
tests/v1/kv_connector/unit/test_lmcache_integration.py
View file @
8d75f22e
...
...
@@ -64,22 +64,6 @@ def test_multimodal_interface():
assumes
(
PlaceholderRange
,
"offset"
)
assumes
(
PlaceholderRange
,
"length"
)
# test a minimal case
import
torch
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils
import
(
apply_mm_hashes_to_token_ids
,
)
token_ids
=
torch
.
arange
(
10
,
dtype
=
torch
.
long
)
mm_hashes
=
[
"0000"
,
"1111"
]
# hex repr of 0 and 4369
mm_positions
=
[
PlaceholderRange
(
offset
=
0
,
length
=
4
),
PlaceholderRange
(
offset
=
5
,
length
=
4
),
]
apply_mm_hashes_to_token_ids
(
token_ids
,
mm_hashes
,
mm_positions
)
assert
token_ids
.
tolist
()
==
[
0
,
0
,
0
,
0
,
4
,
4369
,
4369
,
4369
,
4369
,
9
]
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Requires libcudart.so, not available on ROCm"
...
...
@@ -122,16 +106,6 @@ def test_config_interface():
assumes
(
CacheConfig
,
"block_size"
)
assumes
(
CacheConfig
,
"gpu_memory_utilization"
)
# mla metadata minimal cases
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils
import
(
mla_enabled
,
)
model_config
=
ModelConfig
(
model
=
"deepseek-ai/DeepSeek-R1"
)
assert
mla_enabled
(
model_config
)
model_config
=
ModelConfig
(
model
=
"Qwen/Qwen3-0.6B"
)
assert
not
mla_enabled
(
model_config
)
# kv metadata minimal case
from
vllm.utils.torch_utils
import
get_kv_cache_torch_dtype
...
...
@@ -139,7 +113,7 @@ def test_config_interface():
parallel_config
=
ParallelConfig
()
cache_config
=
CacheConfig
(
cache_dtype
=
"bfloat16"
)
kv_dtype
=
get_kv_cache_torch_dtype
(
cache_config
.
cache_dtype
,
model_config
.
dtype
)
use_mla
=
mla_enabled
(
model_config
)
use_mla
=
False
chunk_size
=
256
num_layer
=
model_config
.
get_num_layers
(
parallel_config
)
num_kv_head
=
model_config
.
get_num_kv_heads
(
parallel_config
)
...
...
@@ -184,43 +158,11 @@ def test_request_interface():
assumes
(
req
,
"num_tokens"
)
assumes
(
req
,
"kv_transfer_params"
,
is_instance_of
=
(
dict
,
NoneType
))
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalKwargsItem
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
assumes
(
MultiModalFeatureSpec
,
"identifier"
)
assumes
(
MultiModalFeatureSpec
,
"mm_position"
)
# minimal case:
from
vllm.multimodal.inputs
import
PlaceholderRange
request
=
Request
(
request_id
=
"test_request"
,
prompt_token_ids
=
[
1
,
2
,
3
],
sampling_params
=
SamplingParams
(
max_tokens
=
10
),
pooling_params
=
None
,
eos_token_id
=
100
,
lora_request
=
None
,
mm_features
=
[
MultiModalFeatureSpec
(
modality
=
"image"
,
identifier
=
"0000"
,
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
mm_position
=
PlaceholderRange
(
offset
=
0
,
length
=
10
),
)
],
)
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils
import
(
extract_mm_features
,
)
mm_hashes
,
mm_positions
=
extract_mm_features
(
request
)
assert
isinstance
(
mm_hashes
,
list
)
assert
len
(
mm_hashes
)
==
1
assert
isinstance
(
mm_positions
,
list
)
assert
len
(
mm_positions
)
==
1
assert
mm_positions
[
0
].
offset
==
0
assert
mm_positions
[
0
].
length
==
10
def
test_new_request_interface
():
# protect against interface changes
...
...
tests/v1/kv_connector/unit/test_multi_connector.py
View file @
8d75f22e
...
...
@@ -77,9 +77,9 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool:
"https://github.com/ROCm/pytorch/issues/2822"
),
)
def
test_multi_
shared_storag
e_connector_consistency
():
def
test_multi_
exampl
e_connector_consistency
():
"""
Tests that MultiConnector with two
SharedStorag
eConnectors saves
Tests that MultiConnector with two
Exampl
eConnectors saves
identical KV cache data to separate storage locations.
"""
storage_1_path
=
Path
(
"storage_1/"
)
...
...
@@ -89,14 +89,14 @@ def test_multi_shared_storage_connector_consistency():
storage_1_path
.
mkdir
()
storage_2_path
.
mkdir
()
# Configure MultiConnector with two
SharedStorag
eConnectors
# Configure MultiConnector with two
Exampl
eConnectors
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"MultiConnector"
,
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
"connectors"
:
[
{
"kv_connector"
:
"Test
SharedStorag
eConnector"
,
"kv_connector"
:
"Test
Exampl
eConnector"
,
"kv_role"
:
"kv_both"
,
"kv_connector_extra_config"
:
{
"shared_storage_path"
:
str
(
storage_1_path
),
...
...
@@ -105,7 +105,7 @@ def test_multi_shared_storage_connector_consistency():
"kv_connector_module_path"
:
"tests.v1.kv_connector.unit.utils"
,
},
{
"kv_connector"
:
"Test
SharedStorag
eConnector"
,
"kv_connector"
:
"Test
Exampl
eConnector"
,
"kv_role"
:
"kv_both"
,
"kv_connector_extra_config"
:
{
"shared_storage_path"
:
str
(
storage_2_path
),
...
...
@@ -427,7 +427,7 @@ class TestMultiConnectorStats:
def
test_build_kv_connector_stats_skips_connectors_without_custom_stats
(
self
):
"""Test that connectors without custom stats (return None) are skipped."""
#
SharedStorag
eConnector doesn't override build_kv_connector_stats,
#
Exampl
eConnector doesn't override build_kv_connector_stats,
# so it returns None and should be skipped
serialized_data
=
{
"NixlConnector"
:
{
...
...
@@ -440,7 +440,7 @@ class TestMultiConnectorStats:
"num_failed_notifications"
:
[],
}
},
"
SharedStorag
eConnector"
:
{
"data"
:
{
"some_field"
:
[
1
,
2
,
3
]}},
"
Exampl
eConnector"
:
{
"data"
:
{
"some_field"
:
[
1
,
2
,
3
]}},
}
stats
=
MultiConnector
.
build_kv_connector_stats
(
data
=
serialized_data
)
...
...
@@ -451,8 +451,8 @@ class TestMultiConnectorStats:
assert
len
(
stats
.
data
)
==
1
assert
"NixlConnector"
in
stats
.
data
assert
isinstance
(
stats
.
data
[
"NixlConnector"
],
NixlKVConnectorStats
)
#
SharedStorag
eConnector should be skipped (returns None)
assert
"
SharedStorag
eConnector"
not
in
stats
.
data
#
Exampl
eConnector should be skipped (returns None)
assert
"
Exampl
eConnector"
not
in
stats
.
data
def
test_build_kv_connector_stats_handles_malformed_data
(
self
):
"""Test that malformed data raises appropriate errors."""
...
...
@@ -527,13 +527,13 @@ class TestMultiConnectorStats:
)
stats2
=
MultiKVConnectorStats
(
data
=
{
"
SharedStorag
eConnector"
:
KVConnectorStats
(
data
=
{
"field"
:
[
1
,
2
]})}
data
=
{
"
Exampl
eConnector"
:
KVConnectorStats
(
data
=
{
"field"
:
[
1
,
2
]})}
)
result
=
stats1
.
aggregate
(
stats2
)
assert
"NixlConnector"
in
result
.
data
assert
"
SharedStorag
eConnector"
in
result
.
data
assert
"
Exampl
eConnector"
in
result
.
data
def
test_reduce
(
self
):
"""Test that reduce() correctly reduces all nested connector stats."""
...
...
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
8d75f22e
...
...
@@ -9,8 +9,10 @@ import textwrap
import
time
import
uuid
from
collections
import
defaultdict
from
unittest.mock
import
patch
from
typing
import
Any
from
unittest.mock
import
MagicMock
,
patch
import
msgspec
import
pytest
import
ray
import
torch
...
...
@@ -18,6 +20,7 @@ import torch
from
vllm
import
LLM
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.distributed.kv_transfer.kv_connector.v1
import
nixl_connector
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
KVConnectorStats
from
vllm.distributed.kv_transfer.kv_connector.v1.multi_connector
import
(
MultiKVConnectorStats
,
...
...
@@ -29,13 +32,16 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorMetadata
,
NixlConnectorScheduler
,
NixlConnectorWorker
,
NixlHandshakePayload
,
NixlKVConnectorStats
,
compute_nixl_compatibility_hash
,
)
from
vllm.distributed.kv_transfer.kv_transfer_state
import
(
ensure_kv_transfer_shutdown
,
has_kv_transfer_group
,
)
from
vllm.forward_context
import
ForwardContext
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
Platform
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
...
...
@@ -317,13 +323,19 @@ def test_kv_transfer_handshake(dist_init):
}
prefill_connector
.
register_kv_caches
(
kv_caches
)
# Simulate EngineCore initialization that would
# gather connector metadata from all workers, the scheduler connector
# expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
# where the first key is the dp_rank, the second key is the tp_rank.
metadata
=
{
0
:
prefill_connector
.
get_handshake_metadata
()}
# Simulate EngineCore initialization that would gather connector
# metadata from all workers
metadata
=
prefill_connector
.
get_handshake_metadata
()
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
expected_agent_metadata
=
decoder
.
decode
(
metadata
.
agent_metadata_bytes
)
# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector
=
scheduler
.
get_kv_connector
()
scheduler_connector
.
set_xfer_handshake_metadata
(
metadata
)
scheduler_connector
.
set_xfer_handshake_metadata
(
{
0
:
metadata
}
)
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
...
...
@@ -362,9 +374,9 @@ def test_kv_transfer_handshake(dist_init):
)
received_metadata
=
mock_add_remote_agent
.
call_args
.
args
assert
received_metadata
[
0
]
==
expected_agent_metadata
assert
received_metadata
[
1
]
==
0
# remote_tp_rank
assert
received_metadata
[
2
]
==
1
# remote_tp_size
assert
metadata
[
0
]
==
received_metadata
[
0
]
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector
.
shutdown
()
...
...
@@ -403,7 +415,6 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
device_id
=
0
,
num_blocks
=
1
,
block_lens
=
self
.
block_len_per_layer
,
attn_backend_name
=
self
.
backend_name
,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout
=
"HND"
,
...
...
@@ -460,6 +471,7 @@ class TestNixlHandshake:
num_xfers
+
6
,
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
...
...
@@ -526,6 +538,7 @@ class TestNixlHandshake:
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
"prefill-id"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
prefill_tp_size
,
...
...
@@ -581,6 +594,7 @@ class TestNixlHandshake:
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-id-
{
i
}
"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
...
...
@@ -651,7 +665,6 @@ class TestNixlHandshake:
device_id
=
0
,
num_blocks
=
1
,
block_lens
=
worker
.
block_len_per_layer
,
attn_backend_name
=
worker
.
backend_name
,
kv_cache_layout
=
mismatched_layout
,
block_size
=
worker
.
block_size
,
)
...
...
@@ -706,7 +719,6 @@ class TestNixlHandshake:
num_blocks
=
1
,
# prefill TP=1, decode TP=2, remote block_lens is double to local
block_lens
=
[
i
*
2
for
i
in
worker
.
block_len_per_layer
],
attn_backend_name
=
worker
.
backend_name
,
kv_cache_layout
=
"HND"
,
block_size
=
worker
.
block_size
,
)
...
...
@@ -746,6 +758,7 @@ def test_kv_connector_stats(dist_init):
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
...
...
@@ -1099,7 +1112,26 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
llm
.
llm_engine
.
engine_core
.
shutdown
()
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
[
"FLASH_ATTN"
,
"TRITON_ATTN"
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
[
pytest
.
param
(
"FLASH_ATTN"
,
marks
=
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Attention backend FLASH_ATTN is not supported on ROCm"
,
),
),
pytest
.
param
(
"ROCM_ATTN"
,
marks
=
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"Attention backend ROCM_ATTN is only supported on ROCm"
,
),
),
"TRITON_ATTN"
,
],
)
def
test_register_kv_caches
(
dist_init
,
attn_backend
,
monkeypatch
):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
...
...
@@ -1121,6 +1153,10 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
backend_cls
=
FlashAttentionBackend
elif
attn_backend
==
"ROCM_ATTN"
:
from
vllm.v1.attention.backends.rocm_attn
import
RocmAttentionBackend
backend_cls
=
RocmAttentionBackend
else
:
# TRITON_ATTN
from
vllm.v1.attention.backends.triton_attn
import
TritonAttentionBackend
...
...
@@ -1139,25 +1175,43 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
}
# Store tensor info for validation
expected_tensor_size
=
shared_tensor
[
0
].
element_size
()
*
shared_tensor
[
0
].
numel
()
test_shape
=
backend_cls
.
get_kv_cache_shape
(
num_blocks
=
1
,
block_size
=
16
,
num_kv_heads
=
1
,
head_size
=
1
)
is_blocks_first
=
len
(
test_shape
)
==
5
and
test_shape
[
0
]
==
1
if
is_blocks_first
:
expected_tensor_size
=
shared_tensor
.
element_size
()
*
shared_tensor
.
numel
()
expected_base_addrs
=
[
shared_tensor
.
data_ptr
(),
unique_tensor
.
data_ptr
(),
]
expected_num_entries
=
2
else
:
expected_tensor_size
=
(
shared_tensor
[
0
].
element_size
()
*
shared_tensor
[
0
].
numel
()
)
expected_base_addrs
=
[
shared_tensor
[
0
].
data_ptr
(),
shared_tensor
[
1
].
data_ptr
(),
unique_tensor
[
0
].
data_ptr
(),
unique_tensor
[
1
].
data_ptr
(),
]
expected_num_entries
=
4
nixl_module
=
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with
(
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
)
as
mock_nixl_wrapper
,
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"
),
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
)
as
mock_thread
,
):
# noqa: E501
patch
(
f
"
{
nixl_module
}
.NixlWrapper"
)
as
mock_nixl_wrapper
,
patch
(
f
"
{
nixl_module
}
.threading.Event"
),
patch
(
f
"
{
nixl_module
}
.threading.Thread"
)
as
mock_thread
,
patch
(
f
"
{
nixl_module
}
.get_attn_backend"
)
as
mock_get_attn_backend
,
):
# Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous
# test run if not mocking.
mock_get_attn_backend
.
return_value
=
backend_cls
# Create connector
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
...
...
@@ -1168,6 +1222,9 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
mock_wrapper_instance
=
mock_nixl_wrapper
.
return_value
connector
.
connector_worker
.
nixl_wrapper
=
mock_wrapper_instance
# Appease NixlHandshakePayload encoding with some bytes
mock_wrapper_instance
.
get_agent_metadata
.
return_value
=
b
"fake_agent_metadata"
# Reassure the shutdown() check that the thread is terminated
mock_thread
.
return_value
.
is_alive
.
return_value
=
False
...
...
@@ -1177,7 +1234,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
# Verify get_reg_descs was called with caches_data
assert
mock_wrapper_instance
.
get_reg_descs
.
called
caches_data
,
_
=
mock_wrapper_instance
.
get_reg_descs
.
call_args
[
0
]
assert
len
(
caches_data
)
==
4
assert
len
(
caches_data
)
==
expected_num_entries
for
i
,
cache_entry
in
enumerate
(
caches_data
):
base_addr
,
size
,
_tp_rank
,
_
=
cache_entry
...
...
@@ -1199,7 +1256,12 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
f
"Expected
{
expected_blocks_count
}
blocks, got
{
len
(
blocks_data
)
}
"
)
expected_block_len
=
expected_tensor_size
//
2
num_blocks
=
2
if
is_blocks_first
:
expected_block_len
=
expected_tensor_size
//
num_blocks
//
2
else
:
expected_block_len
=
expected_tensor_size
//
num_blocks
for
i
,
block_entry
in
enumerate
(
blocks_data
):
block_start_addr
,
block_len
,
tp_rank
=
block_entry
assert
block_len
==
expected_block_len
,
(
...
...
@@ -1296,7 +1358,7 @@ def test_shutdown_cleans_up_resources(dist_init):
patch
.
object
(
nixl_wrapper
,
"remove_remote_agent"
)
as
mock_rem_agent
,
patch
.
object
(
nixl_wrapper
,
"deregister_memory"
)
as
mock_dereg
,
):
worker
.
_recving_transfers
=
{
"req1"
:
[
(
123
,
time
.
perf_counter
())
]}
worker
.
_recving_transfers
=
{
"req1"
:
[
123
]}
worker
.
src_xfer_side_handle
=
456
worker
.
dst_xfer_side_handles
=
{
"engine1"
:
789
}
worker
.
_remote_agents
=
{
"engine1"
:
{
0
:
"agent1"
}}
...
...
@@ -1459,6 +1521,7 @@ def test_handshake_failure_returns_finished(dist_init):
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
...
...
@@ -1508,6 +1571,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
kv_transfer_params
=
{
"remote_block_ids"
:
[
10
,
11
,
12
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_request_id"
:
f
"prefill-
{
request_id
}
"
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
...
...
@@ -1534,3 +1598,194 @@ def test_transfer_setup_failure_returns_finished(dist_init):
# ensure request appears in get_finished
_
,
done_recving
=
connector
.
get_finished
(
finished_req_ids
=
set
())
assert
request_id
in
done_recving
@pytest.mark.parametrize(
    "mismatch_type,config_overrides,version_override,should_fail,enforce_handshake_compat",
    [
        ("vllm_version", {}, {"vllm_version": "0.6.1"}, True, True),
        ("nixl_connector_version", {}, {"connector_version": 37}, True, True),
        ("model_name", {"model": "facebook/opt-350m"}, {}, True, True),
        ("dtype", {"dtype": "bfloat16"}, {}, True, True),
        ("cache_dtype", {"cache_dtype": "fp8"}, {}, True, True),
        (
            "num_kv_heads",
            {"hf_overrides": {"num_key_value_heads": 8}},
            {},
            True,
            True,
        ),
        (
            "num_hidden_layers",
            {"hf_overrides": {"num_hidden_layers": 24}},
            {},
            True,
            True,
        ),
        ("hidden_size", {"hf_overrides": {"hidden_size": 1536}}, {}, True, True),
        ("block_size", {"block_size": 8}, {}, False, True),
        ("matching_config", {}, {}, False, True),
        ("escape_hatch", {"model": "facebook/opt-350m"}, {}, False, False),
    ],
)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_compatibility_hash_validation(
    dist_init,
    mismatch_type,
    config_overrides,
    version_override,
    should_fail,
    enforce_handshake_compat,
):
    """
    Check how the NIXL handshake reacts to a remote compatibility hash.

    A local (decode) worker is built from a fixed config, a "remote" config is
    derived from it by applying the overrides, and the remote hash is computed
    while the requested version override (if any) is patched in. The handshake
    is then driven through a mocked ZMQ socket that returns that payload.

    Parameters:
        mismatch_type: label for the case under test (not read by the body)
        config_overrides: config kwargs overriding the remote instance
        version_override: version dict e.g. {"vllm_version": "0.6.1"}
        should_fail: whether the handshake is expected to raise
        enforce_handshake_compat: value passed as the local
            "enforce_handshake_compat" extra-config entry
    """
    # Local side: the decode worker whose handshake is exercised.
    decode_worker = NixlConnector(
        create_vllm_config(
            model="facebook/opt-125m",
            block_size=16,
            kv_connector_extra_config={
                "enforce_handshake_compat": enforce_handshake_compat
            },
        ),
        KVConnectorRole.WORKER,
    ).connector_worker

    # Remote side: same baseline, with the per-case overrides layered on top.
    remote_config_params: dict[str, Any] = {
        "model": "facebook/opt-125m",
        "block_size": 16,
    }
    remote_config_params.update(config_overrides)
    remote_vllm_config = create_vllm_config(**remote_config_params)

    # At most one version patch applies; it must be active while hashing.
    version_patches = []
    if "vllm_version" in version_override:
        version_patches.append(
            patch("vllm.__version__", version_override["vllm_version"])
        )
    elif "connector_version" in version_override:
        version_patches.append(
            patch.object(
                nixl_connector,
                "NIXL_CONNECTOR_VERSION",
                version_override["connector_version"],
            )
        )

    with contextlib.ExitStack() as stack:
        for version_patch in version_patches:
            stack.enter_context(version_patch)
        remote_hash = compute_nixl_compatibility_hash(
            remote_vllm_config, decode_worker.backend_name
        )

    prefill_block_size = config_overrides.get("block_size", 16)
    encoded_metadata = msgspec.msgpack.encode(
        NixlAgentMetadata(
            engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            agent_metadata=FakeNixlWrapper.AGENT_METADATA,
            kv_caches_base_addr=[0],
            device_id=0,
            num_blocks=1,
            block_lens=[4096 * prefill_block_size],  # slot_size * block_size
            kv_cache_layout="HND",
            block_size=prefill_block_size,
        )
    )
    handshake_payload = NixlHandshakePayload(
        compatibility_hash=remote_hash,
        agent_metadata_bytes=encoded_metadata,
    )

    # The mocked ZMQ socket hands back the crafted handshake payload.
    mock_socket = MagicMock()
    mock_socket.recv.return_value = msgspec.msgpack.encode(handshake_payload)

    handshake_kwargs = {
        "host": "localhost",
        "port": 1234,
        "remote_tp_size": 1,
        "expected_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
    }

    # Failing cases must raise the mismatch error; passing cases run unguarded.
    if should_fail:
        expectation = pytest.raises(
            RuntimeError, match="compatibility hash mismatch"
        )
    else:
        expectation = contextlib.nullcontext()

    # add_remote_agent is stubbed to avoid real NIXL operations, and zmq_ctx
    # is patched so the handshake receives the mock socket.
    with (
        patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
        patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx,
    ):
        mock_zmq_ctx.return_value.__enter__.return_value = mock_socket

        with expectation:
            result = decode_worker._nixl_handshake(**handshake_kwargs)

        if not should_fail:
            # A successful handshake returns the agent mapping.
            assert isinstance(result, dict)
            assert len(result) == 1
@pytest.mark.parametrize(
    "error_scenario",
    [
        "handshake_decode_error",
        "handshake_validation_error",
        "metadata_decode_error",
        "metadata_validation_error",
    ],
)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_handshake_decode_errors(dist_init, error_scenario):
    """
    Verify that malformed handshake bytes surface as RuntimeError.

    Four payload shapes are covered, two per decoding stage:
    - outer NixlHandshakePayload: non-msgpack bytes / msgpack with wrong fields
    - inner NixlAgentMetadata: non-msgpack bytes / msgpack with wrong fields
    """
    decode_worker = NixlConnector(
        create_vllm_config(
            model="facebook/opt-125m",
            block_size=16,
        ),
        KVConnectorRole.WORKER,
    ).connector_worker

    def _valid_envelope(metadata_bytes):
        # A well-formed handshake (matching hash) around the given metadata.
        envelope = NixlHandshakePayload(
            compatibility_hash=decode_worker.compat_hash,
            agent_metadata_bytes=metadata_bytes,
        )
        return msgspec.msgpack.encode(envelope)

    # Builders are lazy so only the selected scenario's payload is constructed.
    payload_builders = {
        "handshake_decode_error": lambda: b"this is not valid msgpack data",
        "handshake_validation_error": lambda: msgspec.msgpack.encode(
            {"wrong_field": "value"}
        ),
        "metadata_decode_error": lambda: _valid_envelope(
            b"invalid msgpack for metadata"
        ),
        "metadata_validation_error": lambda: _valid_envelope(
            msgspec.msgpack.encode({"missing": "fields"})
        ),
    }
    if error_scenario not in payload_builders:
        raise AssertionError(f"{error_scenario} not a valid scenario")
    msg_bytes = payload_builders[error_scenario]()

    # The mocked socket returns the malformed bytes to the handshake.
    mock_socket = MagicMock()
    mock_socket.recv.return_value = msg_bytes

    handshake_kwargs = {
        "host": "localhost",
        "port": 1234,
        "remote_tp_size": 1,
        "expected_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
    }

    with (
        patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
        patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx,
    ):
        mock_zmq_ctx.return_value.__enter__.return_value = mock_socket

        with pytest.raises(RuntimeError):
            decode_worker._nixl_handshake(**handshake_kwargs)
tests/v1/kv_connector/unit/utils.py
View file @
8d75f22e
...
...
@@ -24,8 +24,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata
,
KVConnectorRole
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.
shared_storag
e_connector
import
(
# noqa
SharedStorag
eConnector
,
from
vllm.distributed.kv_transfer.kv_connector.v1.
exampl
e_connector
import
(
# noqa
Exampl
eConnector
,
)
from
vllm.utils.hashing
import
sha256
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
...
...
@@ -90,13 +90,25 @@ def create_vllm_config(
max_model_len
:
int
=
10000
,
enable_chunked_prefill
:
bool
=
True
,
enable_permute_local_kv
:
bool
=
False
,
kv_connector_extra_config
:
dict
[
str
,
Any
]
|
None
=
None
,
dtype
:
str
=
"float16"
,
cache_dtype
:
str
=
"auto"
,
hf_overrides
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
VllmConfig
:
"""Initialize VllmConfig For Testing."""
model_config
=
ModelConfig
(
model
=
model
,
trust_remote_code
=
True
,
dtype
=
"float16"
,
dtype
=
dtype
,
seed
=
42
,
hf_overrides
=
hf_overrides
or
{},
)
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
max_num_seqs
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_model_len
=
max_model_len
,
enable_chunked_prefill
=
enable_chunked_prefill
,
is_encoder_decoder
=
model_config
.
is_encoder_decoder
,
)
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
max_num_seqs
,
...
...
@@ -110,13 +122,14 @@ def create_vllm_config(
block_size
=
block_size
,
gpu_memory_utilization
=
0.9
,
swap_space
=
0
,
cache_dtype
=
"auto"
,
cache_dtype
=
cache_dtype
,
enable_prefix_caching
=
True
,
)
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"NixlConnector"
,
kv_role
=
"kv_both"
,
enable_permute_local_kv
=
enable_permute_local_kv
,
kv_connector_extra_config
=
kv_connector_extra_config
or
{},
)
return
VllmConfig
(
scheduler_config
=
scheduler_config
,
...
...
@@ -188,6 +201,7 @@ def create_request(
do_remote_prefill
=
True
,
do_remote_decode
=
False
,
remote_engine_id
=
"my-engine-id"
,
remote_request_id
=
f
"prefill-
{
request_id
}
"
,
remote_block_ids
=
list
(
range
(
num_remote_blocks
)),
remote_host
=
"my-host"
,
remote_port
=
1234
,
...
...
@@ -257,10 +271,10 @@ def create_model_runner_output(
)
class
Test
SharedStorageConnector
(
SharedStorag
eConnector
):
class
Test
ExampleConnector
(
Exampl
eConnector
):
def
__init__
(
self
,
config
:
VllmConfig
,
role
,
kv_cache_config
):
self
.
name
=
config
.
kv_transfer_config
.
kv_connector_extra_config
[
"name"
]
self
.
_connector
=
SharedStorag
eConnector
(
config
,
role
)
self
.
_connector
=
Exampl
eConnector
(
config
,
role
)
self
.
call_record
:
dict
[
str
,
int
]
=
defaultdict
(
int
)
# Use a unique temp file per connector
self
.
_event_file
=
(
...
...
@@ -387,7 +401,7 @@ class MockKVConnector(KVConnectorBase_V1):
KVConnectorFactory
.
register_connector
(
"Test
SharedStorag
eConnector"
,
__name__
,
Test
SharedStorag
eConnector
.
__name__
"Test
Exampl
eConnector"
,
__name__
,
Test
Exampl
eConnector
.
__name__
)
KVConnectorFactory
.
register_connector
(
...
...
Prev
1
…
12
13
14
15
16
17
18
19
20
…
34
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment