Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
378385b9
Unverified
Commit
378385b9
authored
Jan 22, 2026
by
knlnguyen1802
Committed by
GitHub
Jan 22, 2026
Browse files
[EC Connector] Optimize remote cache check in scheduler (#32585)
Signed-off-by:
knlnguyen1802
<
knlnguyen1802@gmail.com
>
parent
c5487e2b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
76 additions
and
57 deletions
+76
-57
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+24
-16
tests/v1/ec_connector/unit/test_ec_example_connector.py
tests/v1/ec_connector/unit/test_ec_example_connector.py
+31
-17
vllm/distributed/ec_transfer/ec_connector/base.py
vllm/distributed/ec_transfer/ec_connector/base.py
+7
-7
vllm/distributed/ec_transfer/ec_connector/example_connector.py
...distributed/ec_transfer/ec_connector/example_connector.py
+7
-10
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+7
-7
No files found.
tests/v1/core/test_scheduler.py
View file @
378385b9
...
...
@@ -2560,15 +2560,14 @@ def test_ec_connector_cache_hit_external_load(use_kv_connector):
mm_positions
=
mm_positions
,
)[
0
]
# Mock cache hit - encoder cache exists externally
scheduler
.
ec_connector
.
has_cache
s
=
Mock
(
return_value
=
[
True
]
)
# Mock cache hit - encoder cache
has_
exists externally
scheduler
.
ec_connector
.
has_cache
_item
=
Mock
(
return_value
=
True
)
scheduler
.
ec_connector
.
update_state_after_alloc
=
Mock
(
wraps
=
scheduler
.
ec_connector
.
update_state_after_alloc
)
scheduler
.
add_request
(
request
)
output
=
scheduler
.
schedule
()
# Should schedule prompt tokens
scheduled_tokens
=
output
.
num_scheduled_tokens
[
request
.
request_id
]
assert
scheduled_tokens
==
NUM_TOKENS
...
...
@@ -2611,7 +2610,7 @@ def test_ec_connector_cache_miss_computes_locally(use_kv_connector):
)[
0
]
# Mock cache miss - encoder cache doesn't exist externally
scheduler
.
ec_connector
.
has_cache
s
=
Mock
(
return_value
=
[
False
]
)
scheduler
.
ec_connector
.
has_cache
_item
=
Mock
(
return_value
=
False
)
scheduler
.
add_request
(
request_mm_missed
)
output
=
scheduler
.
schedule
()
...
...
@@ -2665,7 +2664,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
PlaceholderRange
(
offset
=
250
,
length
=
NUM_ENCODER_TOKENS_1
),
]
]
has_cache_item_result_map_1
=
{
"hash1_A"
:
False
,
"hash1_B"
:
True
,
"hash1_F"
:
True
}
# Create request with 4 MM items, with 2 identical items
request1
=
create_requests
(
num_requests
=
1
,
...
...
@@ -2676,7 +2675,9 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
)[
0
]
# Mock partial cache hit: 1st and 3rd missing, 2nd and 4th exist
scheduler
.
ec_connector
.
has_caches
=
Mock
(
return_value
=
[
False
,
True
,
False
,
True
])
scheduler
.
ec_connector
.
has_cache_item
=
Mock
(
side_effect
=
lambda
hash_val
:
has_cache_item_result_map_1
[
hash_val
]
)
scheduler
.
ec_connector
.
update_state_after_alloc
=
Mock
(
wraps
=
scheduler
.
ec_connector
.
update_state_after_alloc
)
...
...
@@ -2736,7 +2737,12 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
PlaceholderRange
(
offset
=
250
,
length
=
NUM_ENCODER_TOKENS_2
),
]
]
has_cache_item_result_map_2
=
{
"hash1_C"
:
True
,
"hash1_D"
:
False
,
"hash1_E"
:
False
,
"hash1_A"
:
True
,
}
request2
=
create_requests
(
num_requests
=
1
,
num_tokens
=
NUM_TOKENS_2
,
...
...
@@ -2746,7 +2752,9 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
)[
0
]
# Mock partial cache hit: only hash1_A and hash1_C exist in connector
scheduler
.
ec_connector
.
has_caches
=
Mock
(
return_value
=
[
True
,
False
,
False
,
True
])
scheduler
.
ec_connector
.
has_cache_item
=
Mock
(
side_effect
=
lambda
hash_val
:
has_cache_item_result_map_2
[
hash_val
]
)
scheduler
.
add_request
(
request2
)
output
=
scheduler
.
schedule
()
...
...
@@ -2821,9 +2829,9 @@ def test_ec_connector_schedule_multiple_requests(cache_exist, use_kv_connector):
if
cache_exist
==
"connector_only"
:
# Cache exist in ec_connector
scheduler
.
ec_connector
.
has_cache
s
=
Mock
(
return_value
=
[
True
]
)
scheduler
.
ec_connector
.
has_cache
_item
=
Mock
(
return_value
=
True
)
elif
cache_exist
==
"no_where"
:
scheduler
.
ec_connector
.
has_cache
s
=
Mock
(
return_value
=
[
False
]
)
scheduler
.
ec_connector
.
has_cache
_item
=
Mock
(
return_value
=
False
)
output
=
scheduler
.
schedule
()
assert
len
(
output
.
scheduled_new_reqs
)
==
len
(
requests
)
...
...
@@ -2887,7 +2895,7 @@ def test_ec_connector_unable_to_allocate(use_kv_connector):
)
# Mock ec_connector load external cache behavior
scheduler
.
ec_connector
.
has_cache
s
=
Mock
(
return_value
=
[
True
]
)
scheduler
.
ec_connector
.
has_cache
_item
=
Mock
(
return_value
=
True
)
scheduler
.
ec_connector
.
update_state_after_alloc
=
Mock
(
wraps
=
scheduler
.
ec_connector
.
update_state_after_alloc
)
...
...
@@ -2984,7 +2992,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
)
# Mock cache hit: Both cache exist in connector (at E->PD initially)
scheduler
.
ec_connector
.
has_cache
s
=
Mock
(
return_value
=
[
True
]
)
scheduler
.
ec_connector
.
has_cache
_item
=
Mock
(
return_value
=
True
)
scheduler
.
ec_connector
.
update_state_after_alloc
=
Mock
(
wraps
=
scheduler
.
ec_connector
.
update_state_after_alloc
)
...
...
@@ -3139,9 +3147,9 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
if
cache_exist
==
"connector_only"
:
# Cache exist in ec_connector
scheduler
.
ec_connector
.
has_cache
s
=
Mock
(
return_value
=
[
True
]
)
scheduler
.
ec_connector
.
has_cache
_item
=
Mock
(
return_value
=
True
)
elif
cache_exist
==
"no_where"
:
scheduler
.
ec_connector
.
has_cache
s
=
Mock
(
return_value
=
[
False
]
)
scheduler
.
ec_connector
.
has_cache
_item
=
Mock
(
return_value
=
False
)
# 4th Schedule - this should trigger req_low resumption from waiting
output
=
scheduler
.
schedule
()
...
...
@@ -3259,8 +3267,8 @@ def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connecto
)[
0
]
# Mock cache hit: MM of request1 NOT cached remotely, request2 cached remotely
scheduler
.
ec_connector
.
has_cache
s
=
Mock
(
side_effect
=
lambda
req
:
[
True
,
True
,
Tr
ue
]
i
f
req
==
request2
else
[
False
]
scheduler
.
ec_connector
.
has_cache
_item
=
Mock
(
side_effect
=
lambda
hash_value
:
hash_val
ue
i
n
mm_hashes_list_2
[
0
]
)
scheduler
.
ec_connector
.
update_state_after_alloc
=
Mock
(
wraps
=
scheduler
.
ec_connector
.
update_state_after_alloc
...
...
tests/v1/ec_connector/unit/test_ec_example_connector.py
View file @
378385b9
...
...
@@ -123,15 +123,15 @@ class TestECExampleConnectorBasics:
class
TestCacheExistence
:
"""Test cache existence checking using has_cache
s
() API."""
"""Test cache existence checking using has_cache
_item
() API."""
def
test_has_cache
s
_all_exist_3_items
(
def
test_has_cache
_item
_all_exist_3_items
(
self
,
mock_vllm_config_producer
,
mock_vllm_config_consumer
,
mock_request_with_3_mm
,
):
"""Test has_cache
s
returns True when all 3 caches exist."""
"""Test has_cache
_item
returns True when all 3 caches exist."""
# Test for producer first
producer
=
ECExampleConnector
(
vllm_config
=
mock_vllm_config_producer
,
...
...
@@ -146,8 +146,11 @@ class TestCacheExistence:
encoder_cache
[
mm_hash
]
=
torch
.
randn
(
10
,
768
)
producer
.
save_caches
(
encoder_cache
,
mm_hash
)
# Test using has_caches API
producer_result
=
producer
.
has_caches
(
mock_request_with_3_mm
)
# Test using has_cache_item API
producer_result
=
[
producer
.
has_cache_item
(
mm_feature
.
identifier
)
for
mm_feature
in
mock_request_with_3_mm
.
mm_features
]
# Assert
assert
len
(
producer_result
)
==
3
...
...
@@ -159,14 +162,17 @@ class TestCacheExistence:
role
=
ECConnectorRole
.
SCHEDULER
,
)
# Test using has_caches API
consumer_result
=
consumer
.
has_caches
(
mock_request_with_3_mm
)
# Test using has_cache_item API
consumer_result
=
[
consumer
.
has_cache_item
(
mm_feature
.
identifier
)
for
mm_feature
in
mock_request_with_3_mm
.
mm_features
]
# Assert
assert
len
(
consumer_result
)
==
3
assert
all
(
consumer_result
),
f
"Expected all True, got
{
consumer_result
}
"
def
test_has_cache
s
_none_exist
(
def
test_has_cache
_item
_none_exist
(
self
,
mock_vllm_config_producer
,
mock_request_with_3_mm
):
"""Test has_caches returns False when no caches exist."""
...
...
@@ -176,13 +182,16 @@ class TestCacheExistence:
)
# Test without creating any files
result
=
connector
.
has_caches
(
mock_request_with_3_mm
)
result
=
[
connector
.
has_cache_item
(
mm_feature
.
identifier
)
for
mm_feature
in
mock_request_with_3_mm
.
mm_features
]
# Assert
assert
len
(
result
)
==
3
assert
not
any
(
result
),
f
"Expected all False, got
{
result
}
"
def
test_has_cache
s
_partial_exist
(
def
test_has_cache
_item
_partial_exist
(
self
,
mock_vllm_config_producer
,
mock_request_with_3_mm
):
"""Test has_caches with some caches existing (1 of 3)."""
...
...
@@ -197,7 +206,10 @@ class TestCacheExistence:
connector
.
save_caches
(
encoder_cache
,
mm_hash_second
)
# Test
result
=
connector
.
has_caches
(
mock_request_with_3_mm
)
result
=
[
connector
.
has_cache_item
(
mm_feature
.
identifier
)
for
mm_feature
in
mock_request_with_3_mm
.
mm_features
]
# Assert
assert
len
(
result
)
==
3
...
...
@@ -323,8 +335,11 @@ class TestCacheSaving:
encoder_cache
[
mm_hash
]
=
torch
.
randn
(
10
,
768
)
connector
.
save_caches
(
encoder_cache
,
mm_hash
)
# Verify all files exist using has_caches
result
=
connector
.
has_caches
(
mock_request_with_3_mm
)
# Verify all files exist using has_cache_item
result
=
[
connector
.
has_cache_item
(
mm_feature
.
identifier
)
for
mm_feature
in
mock_request_with_3_mm
.
mm_features
]
assert
all
(
result
),
f
"Not all caches were saved:
{
result
}
"
# Verify each file's content
...
...
@@ -347,10 +362,9 @@ class TestCacheSaving:
# Save should not raise but also not create file
connector
.
save_caches
(
encoder_cache
,
mm_hash
)
# Verify file doesn't exist using has_caches
mock_request
=
MockRequest
(
"req_consumer"
,
[
mm_hash
],
[
10
])
result
=
connector
.
has_caches
(
mock_request
)
assert
not
result
[
0
],
"Consumer should not save caches"
# Verify file doesn't exist using has_cache_item
result
=
connector
.
has_cache_item
(
mm_hash
)
assert
not
result
,
"Consumer should not save caches"
class
TestCacheLoading
:
...
...
vllm/distributed/ec_transfer/ec_connector/base.py
View file @
378385b9
...
...
@@ -182,19 +182,19 @@ class ECConnectorBase(ABC):
# ==============================
@
abstractmethod
def
has_cache
s
(
def
has_cache
_item
(
self
,
request
:
"Reque
st
"
,
)
->
list
[
bool
]
:
identifier
:
st
r
,
)
->
bool
:
"""
Check if encoder cache exists
for each mm data of requests
Check if
a single
encoder cache exists
Args:
request (Reque
st): the
request object
.
identifier (
st
r
): the
identifier of the media
.
Returns:
A
list
bool where
ith
value is True if cache exist for
i
th m
m_data of requests
A bool where value is True if cache exist for
th
e
m
edia
"""
pass
...
...
vllm/distributed/ec_transfer/ec_connector/example_connector.py
View file @
378385b9
...
...
@@ -117,23 +117,20 @@ class ECExampleConnector(ECConnectorBase):
safetensors
.
torch
.
save_file
(
tensors
,
filename
)
logger
.
debug
(
"Save cache successful for mm_hash %s"
,
mm_hash
)
def
has_cache
s
(
def
has_cache
_item
(
self
,
request
:
"Reque
st
"
,
)
->
list
[
bool
]
:
identifier
:
st
r
,
)
->
bool
:
"""
Check if cache exist externally for
each mm_data of request
Check if cache exist externally for
the media
Args:
request (Reque
st): the
request object
.
identifier (
st
r
): the
identifier of the media
.
Returns:
List of b
ool indicate that
ith mm_dat
a exist in cache or not
B
ool indicate that
medi
a exist
s
in cache or not
"""
result
=
[]
for
feature
in
request
.
mm_features
:
result
.
append
(
self
.
_found_match_for_mm_data
(
feature
.
identifier
))
return
result
return
self
.
_found_match_for_mm_data
(
identifier
)
def
update_state_after_alloc
(
self
,
...
...
vllm/v1/core/sched/scheduler.py
View file @
378385b9
...
...
@@ -947,9 +947,6 @@ class Scheduler(SchedulerInterface):
assert
len
(
mm_features
)
>
0
external_load_encoder_input
=
[]
# Check remote cache first
if
self
.
ec_connector
is
not
None
:
remote_cache_has_item
=
self
.
ec_connector
.
has_caches
(
request
)
# NOTE: since scheduler operates on the request level (possibly with
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
...
...
@@ -959,6 +956,7 @@ class Scheduler(SchedulerInterface):
start_pos
=
mm_feature
.
mm_position
.
offset
num_encoder_tokens
=
mm_feature
.
mm_position
.
length
num_encoder_embeds
=
mm_feature
.
mm_position
.
get_num_embeds
item_identifier
=
mm_feature
.
identifier
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
...
...
@@ -993,7 +991,7 @@ class Scheduler(SchedulerInterface):
if
not
self
.
is_encoder_decoder
:
# We are not using the encoder cache for encoder-decoder models,
# yet.
if
request
.
mm_features
[
i
].
identifier
in
mm_hashes_to_schedule
:
if
item_
identifier
in
mm_hashes_to_schedule
:
# The same encoder input has already been scheduled in the
# current step.
continue
...
...
@@ -1051,15 +1049,17 @@ class Scheduler(SchedulerInterface):
if
curr_embeds_end
-
curr_embeds_start
==
0
:
continue
if
self
.
ec_connector
is
not
None
and
remote_cache_has_item
[
i
]:
mm_hashes_to_schedule
.
add
(
request
.
mm_features
[
i
].
identifier
)
if
self
.
ec_connector
is
not
None
and
self
.
ec_connector
.
has_cache_item
(
item_identifier
):
mm_hashes_to_schedule
.
add
(
item_identifier
)
external_load_encoder_input
.
append
(
i
)
num_embeds_to_schedule
+=
num_encoder_embeds
continue
num_embeds_to_schedule
+=
num_encoder_embeds
encoder_compute_budget
-=
num_encoder_embeds
mm_hashes_to_schedule
.
add
(
request
.
mm_features
[
i
].
identifier
)
mm_hashes_to_schedule
.
add
(
item_
identifier
)
encoder_inputs_to_schedule
.
append
(
i
)
return
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment