Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
de35c06c
Unverified
Commit
de35c06c
authored
Mar 17, 2026
by
Yong Hoon Shin
Committed by
GitHub
Mar 17, 2026
Browse files
Make KV connector metadata build overridable via plugin (#37336)
Signed-off-by:
Yong Hoon Shin
<
yhshin@meta.com
>
parent
c0745a85
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
142 additions
and
4 deletions
+142
-4
tests/v1/core/utils.py
tests/v1/core/utils.py
+6
-1
tests/v1/kv_connector/unit/test_scheduler_kv_connector_override.py
...kv_connector/unit/test_scheduler_kv_connector_override.py
+130
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+6
-3
No files found.
tests/v1/core/utils.py
View file @
de35c06c
...
@@ -47,7 +47,7 @@ def create_scheduler(
...
@@ -47,7 +47,7 @@ def create_scheduler(
enable_prefix_caching
:
bool
=
False
,
enable_prefix_caching
:
bool
=
False
,
long_prefill_token_threshold
:
int
=
0
,
long_prefill_token_threshold
:
int
=
0
,
disable_chunked_mm_input
:
bool
=
False
,
disable_chunked_mm_input
:
bool
=
False
,
use_kv_connector
:
None
|
bool
|
MockKVConfig
=
None
,
use_kv_connector
:
None
|
bool
|
str
|
MockKVConfig
=
None
,
num_blocks
:
int
=
10000
,
num_blocks
:
int
=
10000
,
block_size
:
int
=
16
,
block_size
:
int
=
16
,
max_model_len
:
int
|
None
=
None
,
max_model_len
:
int
|
None
=
None
,
...
@@ -107,6 +107,11 @@ def create_scheduler(
...
@@ -107,6 +107,11 @@ def create_scheduler(
"is_async"
:
use_kv_connector
.
is_async
,
"is_async"
:
use_kv_connector
.
is_async
,
},
},
)
)
elif
isinstance
(
use_kv_connector
,
str
):
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
use_kv_connector
,
kv_role
=
"kv_both"
,
)
elif
use_kv_connector
:
elif
use_kv_connector
:
kv_transfer_config
=
KVTransferConfig
(
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"ExampleConnector"
,
kv_connector
=
"ExampleConnector"
,
...
...
tests/v1/kv_connector/unit/test_scheduler_kv_connector_override.py
0 → 100644
View file @
de35c06c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
unittest.mock
import
MagicMock
,
patch
import
pytest
import
vllm.plugins
as
plugins_module
from
tests.v1.core.utils
import
create_requests
,
create_scheduler
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.kv_cache_utils
import
BlockHash
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.request
import
Request
class DummyConnectorMetadata(KVConnectorMetadata):
    """Connector metadata that carries block hashes keyed by request id."""

    def __init__(
        self,
        block_hashes_by_req: dict[str, list[BlockHash]],
    ):
        # Stored as given (no copy) so the test can compare it directly
        # against each request's own ``block_hashes``.
        self.block_hashes_by_req = block_hashes_by_req
class DummyKVConnector(KVConnectorBase_V1):
    """Minimal KV connector for exercising the scheduler metadata hook.

    All load/save hooks are no-ops. ``build_connector_meta`` requires the
    scheduler output to carry a ``block_hashes_by_req`` attribute and wraps
    it in a ``DummyConnectorMetadata``.
    """

    def __init__(self, vllm_config, role, kv_cache_config=None):
        super().__init__(vllm_config, role, kv_cache_config)

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """Always report zero matched tokens, with ``False`` as the flag."""
        return 0, False

    def update_state_after_alloc(
        self,
        request: Request,
        blocks: KVCacheBlocks,
        num_external_tokens: int,
    ):
        """No-op: this connector keeps no allocation state."""
        return None

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Wrap ``scheduler_output.block_hashes_by_req`` in metadata.

        Raises:
            AssertionError: if the attribute is absent or ``None``.
        """
        hashes = getattr(scheduler_output, "block_hashes_by_req", None)
        assert hashes is not None, (
            "DummyKVConnector expected 'block_hashes_by_req' on scheduler_output"
        )
        return DummyConnectorMetadata(block_hashes_by_req=hashes)

    def start_load_kv(self, kv_caches, finished_req_ids):
        """No-op."""
        return None

    def wait_for_layer_load(self, layer_name):
        """No-op."""
        return None

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
        """No-op."""
        return None

    def wait_for_save(self):
        """No-op."""
        return None
def _my_plugin():
    """Fake plugin entry point used by the test.

    Registers ``DummyKVConnector`` with ``KVConnectorFactory`` and replaces
    ``Scheduler._build_kv_connector_meta`` with a version that attaches the
    block hashes of every scheduled request to the scheduler output before
    asking the connector to build its metadata.
    """
    connector_cls_name = DummyKVConnector.__name__
    KVConnectorFactory.register_connector(
        "DummyKVConnector",
        __name__,
        connector_cls_name,
    )

    def _custom_build_kv_connector_meta(
        self, connector: KVConnectorBase_V1, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        # One entry per request id present in num_scheduled_tokens.
        hashes_by_req: dict[str, list[BlockHash]] = {
            req_id: self.requests[req_id].block_hashes
            for req_id in scheduler_output.num_scheduled_tokens
        }
        scheduler_output.block_hashes_by_req = hashes_by_req  # type: ignore[attr-defined]
        return connector.build_connector_meta(scheduler_output)

    # Class-level monkey-patch: affects every Scheduler instance from now on.
    Scheduler._build_kv_connector_meta = _custom_build_kv_connector_meta
@pytest.fixture
def _load_plugin():
    """Load the fake plugin through the real load_general_plugins() path.

    The plugin monkey-patches ``Scheduler._build_kv_connector_meta`` at class
    level, so the original method is captured up front and restored on
    teardown; otherwise the override would leak into every test that runs
    afterwards in the same process.
    """
    ep = MagicMock()
    ep.name = "dummy_kv_connector_plugin"
    ep.value = f"{__name__}:_my_plugin"
    ep.load.return_value = _my_plugin

    # Remember the un-patched scheduler hook so it can be put back.
    original_build_meta = Scheduler._build_kv_connector_meta

    # Reset the global guard so load_general_plugins() actually runs.
    plugins_module.plugins_loaded = False
    try:
        with patch("importlib.metadata.entry_points", return_value=[ep]):
            plugins_module.load_general_plugins()
        yield
    finally:
        # Undo the plugin's class-level override and reset the guard again so
        # other tests are not affected, even if plugin loading itself failed.
        Scheduler._build_kv_connector_meta = original_build_meta
        plugins_module.plugins_loaded = False
def test_connector_receives_block_hashes(_load_plugin):
    """The plugin-overridden metadata build passes block hashes through.

    Schedules three requests of 48 tokens with a block size of 16 and checks
    that the resulting connector metadata holds, for each request, exactly
    the request's own block hashes.
    """
    block_size = 16
    num_tokens = 48  # 3 full blocks worth of tokens

    scheduler = create_scheduler(
        use_kv_connector="DummyKVConnector",
        block_size=block_size,
    )
    requests = create_requests(
        num_requests=3,
        num_tokens=num_tokens,
        block_size=block_size,
    )
    for request in requests:
        scheduler.add_request(request)

    sched_output = scheduler.schedule()

    # Verify the connector metadata was built with block hashes.
    connector_meta = sched_output.kv_connector_metadata
    assert isinstance(connector_meta, DummyConnectorMetadata)
    assert len(connector_meta.block_hashes_by_req) == 3
    for request in requests:
        assert request.request_id in connector_meta.block_hashes_by_req
        hashes = connector_meta.block_hashes_by_req[request.request_id]
        # Each request has num_tokens / block_size = 3 full block hashes.
        assert len(hashes) == (num_tokens // block_size)
        assert hashes == request.block_hashes
vllm/v1/core/sched/scheduler.py
View file @
de35c06c
...
@@ -910,9 +910,7 @@ class Scheduler(SchedulerInterface):
...
@@ -910,9 +910,7 @@ class Scheduler(SchedulerInterface):
# 2. Wrap up all the KV cache load / save ops into an opaque object
# 2. Wrap up all the KV cache load / save ops into an opaque object
# 3. Clear the internal states of the connector
# 3. Clear the internal states of the connector
if
self
.
connector
is
not
None
:
if
self
.
connector
is
not
None
:
meta
:
KVConnectorMetadata
=
self
.
connector
.
build_connector_meta
(
meta
=
self
.
_build_kv_connector_meta
(
self
.
connector
,
scheduler_output
)
scheduler_output
)
scheduler_output
.
kv_connector_metadata
=
meta
scheduler_output
.
kv_connector_metadata
=
meta
# Build the connector meta for ECConnector
# Build the connector meta for ECConnector
...
@@ -926,6 +924,11 @@ class Scheduler(SchedulerInterface):
...
@@ -926,6 +924,11 @@ class Scheduler(SchedulerInterface):
self
.
_update_after_schedule
(
scheduler_output
)
self
.
_update_after_schedule
(
scheduler_output
)
return
scheduler_output
return
scheduler_output
def _build_kv_connector_meta(
    self,
    connector: KVConnectorBase_V1,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Build the KV connector metadata for a scheduler output.

    The default implementation simply delegates to
    ``connector.build_connector_meta``. It is a separate method so that the
    metadata build step can be replaced without touching ``schedule()``.

    Args:
        connector: the KV connector to ask for metadata.
        scheduler_output: the scheduler output handed to the connector.

    Returns:
        Whatever ``connector.build_connector_meta(scheduler_output)`` returns.
    """
    connector_meta = connector.build_connector_meta(scheduler_output)
    return connector_meta
def
_preempt_request
(
self
,
request
:
Request
,
timestamp
:
float
)
->
None
:
def
_preempt_request
(
self
,
request
:
Request
,
timestamp
:
float
)
->
None
:
"""Preempt a request and put it back to the waiting queue.
"""Preempt a request and put it back to the waiting queue.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment