Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4417f84
Unverified
Commit
f4417f84
authored
Dec 11, 2025
by
Martin Hickey
Committed by
GitHub
Dec 11, 2025
Browse files
[KVConnector] Add KV events to KV Connectors (#28309)
Signed-off-by:
Martin Hickey
<
martin.hickey@ie.ibm.com
>
parent
a11f4a81
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1036 additions
and
15 deletions
+1036
-15
tests/v1/kv_connector/unit/test_lmcache_connector.py
tests/v1/kv_connector/unit/test_lmcache_connector.py
+756
-0
vllm/distributed/kv_events.py
vllm/distributed/kv_events.py
+129
-1
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+15
-0
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+9
-1
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
...tributed/kv_transfer/kv_connector/v1/lmcache_connector.py
+114
-3
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+6
-0
vllm/v1/outputs.py
vllm/v1/outputs.py
+4
-0
vllm/v1/worker/kv_connector_model_runner_mixin.py
vllm/v1/worker/kv_connector_model_runner_mixin.py
+3
-10
No files found.
tests/v1/kv_connector/unit/test_lmcache_connector.py
0 → 100644
View file @
f4417f84
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
unittest.mock
import
MagicMock
import
pytest
from
vllm.distributed.kv_events
import
BlockStored
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector
import
(
LMCacheConnectorV1
,
LMCacheKVEvents
,
)
from
vllm.v1.outputs
import
KVConnectorOutput
@
pytest
.
fixture
def
mock_lmcache_engine_event
():
"""Create a mock event object that mimics what the lmcache engine returns."""
class
MockEvent
:
def
__init__
(
self
,
block_hashes
,
parent_block_hash
,
token_ids
,
lora_id
,
block_size
,
medium
,
):
self
.
block_hashes
=
block_hashes
self
.
parent_block_hash
=
parent_block_hash
self
.
token_ids
=
token_ids
self
.
lora_id
=
lora_id
self
.
block_size
=
block_size
self
.
medium
=
medium
return
MockEvent
(
block_hashes
=
[
"hash1"
,
"hash2"
],
parent_block_hash
=
"parent_hash"
,
token_ids
=
[
1
,
2
,
3
,
4
],
lora_id
=
None
,
block_size
=
16
,
medium
=
"GPU"
,
)
@
pytest
.
fixture
def
mock_connector
():
"""Create a mock LMCacheConnectorV1 instance with mocked dependencies."""
connector
=
MagicMock
(
spec
=
LMCacheConnectorV1
)
connector
.
_kv_cache_events
=
None
connector
.
_lmcache_engine
=
MagicMock
()
# Make the methods use the real implementation
connector
.
get_kv_connector_kv_cache_events
=
(
LMCacheConnectorV1
.
get_kv_connector_kv_cache_events
.
__get__
(
connector
,
LMCacheConnectorV1
)
)
connector
.
update_connector_output
=
(
LMCacheConnectorV1
.
update_connector_output
.
__get__
(
connector
,
LMCacheConnectorV1
)
)
connector
.
take_events
=
LMCacheConnectorV1
.
take_events
.
__get__
(
connector
,
LMCacheConnectorV1
)
return
connector
class
TestGetKVConnectorKVCacheEvents
:
"""Test get_kv_connector_kv_cache_events method."""
def
test_returns_none_when_no_events
(
self
,
mock_connector
):
"""Test that None is returned when lmcache engine has no events."""
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
None
result
=
mock_connector
.
get_kv_connector_kv_cache_events
()
assert
result
is
None
mock_connector
.
_lmcache_engine
.
get_kv_events
.
assert_called_once
()
def
test_returns_none_when_empty_list
(
self
,
mock_connector
):
"""Test that None is returned when lmcache engine returns empty list."""
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
[]
result
=
mock_connector
.
get_kv_connector_kv_cache_events
()
assert
result
is
None
def
test_converts_single_event
(
self
,
mock_connector
,
mock_lmcache_engine_event
):
"""Test conversion of a single event from lmcache engine format."""
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
[
mock_lmcache_engine_event
]
result
=
mock_connector
.
get_kv_connector_kv_cache_events
()
assert
result
is
not
None
assert
isinstance
(
result
,
LMCacheKVEvents
)
assert
result
.
get_number_of_workers
()
==
1
events
=
result
.
get_all_events
()
assert
len
(
events
)
==
1
assert
isinstance
(
events
[
0
],
BlockStored
)
assert
events
[
0
].
block_hashes
==
[
"hash1"
,
"hash2"
]
assert
events
[
0
].
parent_block_hash
==
"parent_hash"
assert
events
[
0
].
token_ids
==
[
1
,
2
,
3
,
4
]
assert
events
[
0
].
lora_id
is
None
assert
events
[
0
].
block_size
==
16
assert
events
[
0
].
medium
==
"GPU"
def
test_converts_multiple_events
(
self
,
mock_connector
):
"""Test conversion of multiple events from lmcache engine format."""
class
MockEvent
:
def
__init__
(
self
,
i
):
self
.
block_hashes
=
[
f
"hash
{
i
}
"
]
self
.
parent_block_hash
=
f
"parent
{
i
}
"
self
.
token_ids
=
[
i
]
self
.
lora_id
=
None
self
.
block_size
=
16
self
.
medium
=
"GPU"
events
=
[
MockEvent
(
i
)
for
i
in
range
(
5
)]
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
events
result
=
mock_connector
.
get_kv_connector_kv_cache_events
()
assert
result
is
not
None
assert
isinstance
(
result
,
LMCacheKVEvents
)
converted_events
=
result
.
get_all_events
()
assert
len
(
converted_events
)
==
5
for
i
,
event
in
enumerate
(
converted_events
):
assert
isinstance
(
event
,
BlockStored
)
assert
event
.
block_hashes
==
[
f
"hash
{
i
}
"
]
assert
event
.
parent_block_hash
==
f
"parent
{
i
}
"
assert
event
.
token_ids
==
[
i
]
def
test_preserves_event_attributes
(
self
,
mock_connector
):
"""Test that all event attributes are correctly preserved."""
class
MockEventWithLora
:
def
__init__
(
self
):
self
.
block_hashes
=
[
"hash_a"
,
"hash_b"
,
"hash_c"
]
self
.
parent_block_hash
=
"parent_xyz"
self
.
token_ids
=
[
100
,
200
,
300
]
self
.
lora_id
=
42
self
.
block_size
=
32
self
.
medium
=
"DISK"
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
[
MockEventWithLora
()
]
result
=
mock_connector
.
get_kv_connector_kv_cache_events
()
events
=
result
.
get_all_events
()
event
=
events
[
0
]
assert
event
.
block_hashes
==
[
"hash_a"
,
"hash_b"
,
"hash_c"
]
assert
event
.
parent_block_hash
==
"parent_xyz"
assert
event
.
token_ids
==
[
100
,
200
,
300
]
assert
event
.
lora_id
==
42
assert
event
.
block_size
==
32
assert
event
.
medium
==
"DISK"
def
test_handles_none_parent_block_hash
(
self
,
mock_connector
):
"""Test handling of events with None parent_block_hash."""
class
MockEventNoParent
:
def
__init__
(
self
):
self
.
block_hashes
=
[
"hash1"
]
self
.
parent_block_hash
=
None
self
.
token_ids
=
[
1
,
2
]
self
.
lora_id
=
None
self
.
block_size
=
16
self
.
medium
=
"GPU"
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
[
MockEventNoParent
()
]
result
=
mock_connector
.
get_kv_connector_kv_cache_events
()
events
=
result
.
get_all_events
()
assert
events
[
0
].
parent_block_hash
is
None
class
TestUpdateConnectorOutput
:
"""Test update_connector_output method."""
def
test_does_nothing_when_kv_cache_events_is_none
(
self
,
mock_connector
):
"""Test that method returns early when kv_cache_events is None."""
connector_output
=
KVConnectorOutput
(
kv_cache_events
=
None
)
mock_connector
.
update_connector_output
(
connector_output
)
assert
mock_connector
.
_kv_cache_events
is
None
def
test_does_nothing_when_kv_cache_events_is_not_lmcache_kv_events
(
self
,
mock_connector
):
"""Test that method returns early when kv_cache_events is not
LMCacheKVEvents."""
# Create a mock object that is not LMCacheKVEvents
fake_events
=
MagicMock
()
connector_output
=
KVConnectorOutput
(
kv_cache_events
=
fake_events
)
mock_connector
.
update_connector_output
(
connector_output
)
assert
mock_connector
.
_kv_cache_events
is
None
def
test_sets_kv_cache_events_when_none
(
self
,
mock_connector
):
"""Test that _kv_cache_events is set when it was None."""
kv_events
=
LMCacheKVEvents
(
num_workers
=
1
)
event
=
BlockStored
(
block_hashes
=
[
"hash1"
],
parent_block_hash
=
None
,
token_ids
=
[
1
,
2
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
kv_events
.
add_events
([
event
])
connector_output
=
KVConnectorOutput
(
kv_cache_events
=
kv_events
)
mock_connector
.
update_connector_output
(
connector_output
)
assert
mock_connector
.
_kv_cache_events
is
kv_events
def
test_adds_events_when_kv_cache_events_already_exists
(
self
,
mock_connector
):
"""Test that events are added when _kv_cache_events already exists."""
# Set up existing events
existing_events
=
LMCacheKVEvents
(
num_workers
=
2
)
event1
=
BlockStored
(
block_hashes
=
[
"hash1"
],
parent_block_hash
=
None
,
token_ids
=
[
1
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
existing_events
.
add_events
([
event1
])
existing_events
.
add_events
([
event1
])
# Simulate 2 workers reporting
mock_connector
.
_kv_cache_events
=
existing_events
# Create new events to add
new_events
=
LMCacheKVEvents
(
num_workers
=
1
)
event2
=
BlockStored
(
block_hashes
=
[
"hash2"
],
parent_block_hash
=
None
,
token_ids
=
[
2
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
new_events
.
add_events
([
event2
])
connector_output
=
KVConnectorOutput
(
kv_cache_events
=
new_events
)
mock_connector
.
update_connector_output
(
connector_output
)
# Check that events were added
all_events
=
mock_connector
.
_kv_cache_events
.
get_all_events
()
assert
len
(
all_events
)
==
3
# 2 from existing + 1 from new
assert
event1
in
all_events
assert
event2
in
all_events
def
test_increments_workers_when_kv_cache_events_already_exists
(
self
,
mock_connector
):
"""Test that worker count is incremented correctly."""
# Set up existing events with 2 workers
existing_events
=
LMCacheKVEvents
(
num_workers
=
2
)
mock_connector
.
_kv_cache_events
=
existing_events
# Create new events from 3 workers
new_events
=
LMCacheKVEvents
(
num_workers
=
3
)
event
=
BlockStored
(
block_hashes
=
[
"hash1"
],
parent_block_hash
=
None
,
token_ids
=
[
1
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
new_events
.
add_events
([
event
])
connector_output
=
KVConnectorOutput
(
kv_cache_events
=
new_events
)
mock_connector
.
update_connector_output
(
connector_output
)
# Worker count should be 2 + 3 = 5
assert
mock_connector
.
_kv_cache_events
.
get_number_of_workers
()
==
5
def
test_multiple_updates
(
self
,
mock_connector
):
"""Test multiple consecutive updates."""
# First update
events1
=
LMCacheKVEvents
(
num_workers
=
1
)
event1
=
BlockStored
(
block_hashes
=
[
"hash1"
],
parent_block_hash
=
None
,
token_ids
=
[
1
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
events1
.
add_events
([
event1
])
output1
=
KVConnectorOutput
(
kv_cache_events
=
events1
)
mock_connector
.
update_connector_output
(
output1
)
# Second update
events2
=
LMCacheKVEvents
(
num_workers
=
2
)
event2
=
BlockStored
(
block_hashes
=
[
"hash2"
],
parent_block_hash
=
None
,
token_ids
=
[
2
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
events2
.
add_events
([
event2
])
output2
=
KVConnectorOutput
(
kv_cache_events
=
events2
)
mock_connector
.
update_connector_output
(
output2
)
# Third update
events3
=
LMCacheKVEvents
(
num_workers
=
1
)
event3
=
BlockStored
(
block_hashes
=
[
"hash3"
],
parent_block_hash
=
None
,
token_ids
=
[
3
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
events3
.
add_events
([
event3
])
output3
=
KVConnectorOutput
(
kv_cache_events
=
events3
)
mock_connector
.
update_connector_output
(
output3
)
# Check final state
all_events
=
mock_connector
.
_kv_cache_events
.
get_all_events
()
assert
len
(
all_events
)
==
3
assert
mock_connector
.
_kv_cache_events
.
get_number_of_workers
()
==
4
# 1+2+1
def
test_updates_with_empty_events
(
self
,
mock_connector
):
"""Test updating with empty event lists."""
# First update with actual events
events1
=
LMCacheKVEvents
(
num_workers
=
1
)
event1
=
BlockStored
(
block_hashes
=
[
"hash1"
],
parent_block_hash
=
None
,
token_ids
=
[
1
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
events1
.
add_events
([
event1
])
output1
=
KVConnectorOutput
(
kv_cache_events
=
events1
)
mock_connector
.
update_connector_output
(
output1
)
# Second update with empty events
events2
=
LMCacheKVEvents
(
num_workers
=
2
)
# No events added
output2
=
KVConnectorOutput
(
kv_cache_events
=
events2
)
mock_connector
.
update_connector_output
(
output2
)
# Should still have the original event
all_events
=
mock_connector
.
_kv_cache_events
.
get_all_events
()
assert
len
(
all_events
)
==
1
assert
mock_connector
.
_kv_cache_events
.
get_number_of_workers
()
==
3
class
TestTakeEvents
:
"""Test take_events method."""
def
test_yields_nothing_when_kv_cache_events_is_none
(
self
,
mock_connector
):
"""Test that nothing is yielded when _kv_cache_events is None."""
mock_connector
.
_kv_cache_events
=
None
events
=
list
(
mock_connector
.
take_events
())
assert
events
==
[]
def
test_yields_events_and_clears
(
self
,
mock_connector
):
"""Test that events are yielded and then cleared."""
# Set up events
kv_events
=
LMCacheKVEvents
(
num_workers
=
1
)
event1
=
BlockStored
(
block_hashes
=
[
"hash1"
],
parent_block_hash
=
None
,
token_ids
=
[
1
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
event2
=
BlockStored
(
block_hashes
=
[
"hash2"
],
parent_block_hash
=
None
,
token_ids
=
[
2
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
kv_events
.
add_events
([
event1
,
event2
])
mock_connector
.
_kv_cache_events
=
kv_events
# Take events
events
=
list
(
mock_connector
.
take_events
())
# Check that events were yielded
assert
len
(
events
)
==
2
assert
event1
in
events
assert
event2
in
events
# Check that _kv_cache_events was cleared
assert
mock_connector
.
_kv_cache_events
is
None
def
test_aggregates_before_yielding
(
self
,
mock_connector
):
"""Test that events are aggregated before yielding."""
# Set up events from multiple workers
kv_events
=
LMCacheKVEvents
(
num_workers
=
3
)
common_event
=
BlockStored
(
block_hashes
=
[
"hash_common"
],
parent_block_hash
=
None
,
token_ids
=
[
1
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
uncommon_event
=
BlockStored
(
block_hashes
=
[
"hash_uncommon"
],
parent_block_hash
=
None
,
token_ids
=
[
2
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
# All 3 workers report common_event
kv_events
.
add_events
([
common_event
])
kv_events
.
add_events
([
common_event
])
kv_events
.
add_events
([
common_event
])
# Only 1 worker reports uncommon_event
kv_events
.
add_events
([
uncommon_event
])
mock_connector
.
_kv_cache_events
=
kv_events
# Take events
events
=
list
(
mock_connector
.
take_events
())
# Only the common event should be yielded
assert
len
(
events
)
==
1
assert
events
[
0
]
==
common_event
def
test_multiple_take_events_calls
(
self
,
mock_connector
):
"""Test calling take_events multiple times."""
# First call with events
kv_events1
=
LMCacheKVEvents
(
num_workers
=
1
)
event1
=
BlockStored
(
block_hashes
=
[
"hash1"
],
parent_block_hash
=
None
,
token_ids
=
[
1
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
kv_events1
.
add_events
([
event1
])
mock_connector
.
_kv_cache_events
=
kv_events1
events1
=
list
(
mock_connector
.
take_events
())
assert
len
(
events1
)
==
1
assert
events1
[
0
]
==
event1
assert
mock_connector
.
_kv_cache_events
is
None
# Second call with no events
events2
=
list
(
mock_connector
.
take_events
())
assert
events2
==
[]
# Third call after adding new events
kv_events2
=
LMCacheKVEvents
(
num_workers
=
1
)
event2
=
BlockStored
(
block_hashes
=
[
"hash2"
],
parent_block_hash
=
None
,
token_ids
=
[
2
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
kv_events2
.
add_events
([
event2
])
mock_connector
.
_kv_cache_events
=
kv_events2
events3
=
list
(
mock_connector
.
take_events
())
assert
len
(
events3
)
==
1
assert
events3
[
0
]
==
event2
def
test_yields_empty_after_aggregation_removes_all
(
self
,
mock_connector
):
"""Test that nothing is yielded if aggregation removes all events."""
# Set up events from 2 workers with no common events
kv_events
=
LMCacheKVEvents
(
num_workers
=
2
)
event1
=
BlockStored
(
block_hashes
=
[
"hash1"
],
parent_block_hash
=
None
,
token_ids
=
[
1
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
event2
=
BlockStored
(
block_hashes
=
[
"hash2"
],
parent_block_hash
=
None
,
token_ids
=
[
2
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
# Worker 1 reports event1
kv_events
.
add_events
([
event1
])
# Worker 2 reports event2
kv_events
.
add_events
([
event2
])
mock_connector
.
_kv_cache_events
=
kv_events
# Take events
events
=
list
(
mock_connector
.
take_events
())
# No common events, so nothing should be yielded
assert
events
==
[]
assert
mock_connector
.
_kv_cache_events
is
None
class
TestIntegrationScenarios
:
"""Test integration scenarios."""
def
test_full_workflow
(
self
,
mock_connector
,
mock_lmcache_engine_event
):
"""Test a complete workflow from getting events to taking them."""
# Step 1: Get events from lmcache engine
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
[
mock_lmcache_engine_event
]
kv_events
=
mock_connector
.
get_kv_connector_kv_cache_events
()
assert
kv_events
is
not
None
assert
len
(
kv_events
.
get_all_events
())
==
1
# Step 2: Update connector output (simulate receiving from worker)
output1
=
KVConnectorOutput
(
kv_cache_events
=
kv_events
)
mock_connector
.
update_connector_output
(
output1
)
assert
mock_connector
.
_kv_cache_events
is
not
None
# Step 3: Take events
taken_events
=
list
(
mock_connector
.
take_events
())
assert
len
(
taken_events
)
==
1
assert
mock_connector
.
_kv_cache_events
is
None
def
test_multiple_workers_workflow
(
self
,
mock_connector
):
"""Test workflow with multiple workers."""
class
MockEvent
:
def
__init__
(
self
,
hash_val
):
self
.
block_hashes
=
[
hash_val
]
self
.
parent_block_hash
=
None
self
.
token_ids
=
[
1
]
self
.
lora_id
=
None
self
.
block_size
=
16
self
.
medium
=
"GPU"
# Worker 1
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
[
MockEvent
(
"hash_common"
),
MockEvent
(
"hash_worker1"
),
]
kv_events1
=
mock_connector
.
get_kv_connector_kv_cache_events
()
output1
=
KVConnectorOutput
(
kv_cache_events
=
kv_events1
)
mock_connector
.
update_connector_output
(
output1
)
# Worker 2
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
[
MockEvent
(
"hash_common"
),
MockEvent
(
"hash_worker2"
),
]
kv_events2
=
mock_connector
.
get_kv_connector_kv_cache_events
()
output2
=
KVConnectorOutput
(
kv_cache_events
=
kv_events2
)
mock_connector
.
update_connector_output
(
output2
)
# Take events (should only get common events)
taken_events
=
list
(
mock_connector
.
take_events
())
# With aggregation, only events reported by both workers should be present
# In this case, hash_common was reported by both
event_hashes
=
[
e
.
block_hashes
[
0
]
for
e
in
taken_events
]
assert
"hash_common"
in
event_hashes
def
test_empty_workflow
(
self
,
mock_connector
):
"""Test workflow when there are no events at any stage."""
# Get events returns None
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
None
kv_events
=
mock_connector
.
get_kv_connector_kv_cache_events
()
assert
kv_events
is
None
# Update with None
output
=
KVConnectorOutput
(
kv_cache_events
=
None
)
mock_connector
.
update_connector_output
(
output
)
# Take events
taken_events
=
list
(
mock_connector
.
take_events
())
assert
taken_events
==
[]
assert
mock_connector
.
_kv_cache_events
is
None
def
test_repeated_cycles
(
self
,
mock_connector
):
"""Test multiple cycles of the complete workflow."""
class
MockEvent
:
def
__init__
(
self
,
cycle_num
):
self
.
block_hashes
=
[
f
"hash_cycle_
{
cycle_num
}
"
]
self
.
parent_block_hash
=
None
self
.
token_ids
=
[
cycle_num
]
self
.
lora_id
=
None
self
.
block_size
=
16
self
.
medium
=
"GPU"
for
cycle
in
range
(
3
):
# Get events
mock_connector
.
_lmcache_engine
.
get_kv_events
.
return_value
=
[
MockEvent
(
cycle
)
]
kv_events
=
mock_connector
.
get_kv_connector_kv_cache_events
()
# Update
output
=
KVConnectorOutput
(
kv_cache_events
=
kv_events
)
mock_connector
.
update_connector_output
(
output
)
# Take
taken_events
=
list
(
mock_connector
.
take_events
())
# Verify
assert
len
(
taken_events
)
==
1
assert
taken_events
[
0
].
block_hashes
[
0
]
==
f
"hash_cycle_
{
cycle
}
"
assert
mock_connector
.
_kv_cache_events
is
None
def
test_lmcache_kv_events_aggregation
(
self
):
"""
Test LMCacheKVEvents aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.v1.outputs
import
ModelRunnerOutput
# Create KVOutputAggregator for 3 workers (simulating TP=3)
aggregator
=
KVOutputAggregator
(
expected_finished_count
=
3
)
# Define common and unique events
common_event
=
BlockStored
(
block_hashes
=
[
"hash_common"
],
parent_block_hash
=
"parent_common"
,
token_ids
=
[
1
,
2
,
3
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
worker1_unique_event
=
BlockStored
(
block_hashes
=
[
"hash_worker1"
],
parent_block_hash
=
"parent_w1"
,
token_ids
=
[
4
,
5
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
worker2_unique_event
=
BlockStored
(
block_hashes
=
[
"hash_worker2"
],
parent_block_hash
=
"parent_w2"
,
token_ids
=
[
6
,
7
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
worker3_unique_event
=
BlockStored
(
block_hashes
=
[
"hash_worker3"
],
parent_block_hash
=
"parent_w3"
,
token_ids
=
[
8
,
9
],
block_size
=
16
,
lora_id
=
None
,
medium
=
"GPU"
,
)
# Create events for each worker
# Worker 0: reports common event and its unique event
worker0_events
=
LMCacheKVEvents
(
num_workers
=
1
)
worker0_events
.
add_events
([
common_event
,
worker1_unique_event
])
# Worker 1: reports common event and its unique event
worker1_events
=
LMCacheKVEvents
(
num_workers
=
1
)
worker1_events
.
add_events
([
common_event
,
worker2_unique_event
])
# Worker 2: reports common event and its unique event
worker2_events
=
LMCacheKVEvents
(
num_workers
=
1
)
worker2_events
.
add_events
([
common_event
,
worker3_unique_event
])
# Create ModelRunnerOutput instances for each worker
worker_outputs
=
[]
for
i
,
worker_events
in
enumerate
(
[
worker0_events
,
worker1_events
,
worker2_events
]
):
output
=
ModelRunnerOutput
(
req_ids
=
[
f
"req_
{
i
}
"
],
req_id_to_index
=
{
f
"req_
{
i
}
"
:
0
},
sampled_token_ids
=
[[
123
]],
# dummy token
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
[
None
],
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
set
([
f
"req_
{
i
}
_send"
])
if
i
<
2
else
None
,
# Workers 0,1 finished sending
finished_recving
=
set
([
f
"req_
{
i
}
_recv"
])
if
i
>
0
else
None
,
# Workers 1,2 finished receiving
kv_cache_events
=
worker_events
,
),
)
worker_outputs
.
append
(
output
)
# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
aggregated_output
=
aggregator
.
aggregate
(
worker_outputs
,
output_rank
=
0
)
kv_cache_events
=
aggregated_output
.
kv_connector_output
.
kv_cache_events
assert
isinstance
(
kv_cache_events
,
LMCacheKVEvents
)
# After aggregation, events should be combined from all workers
# The aggregator doesn't automatically aggregate events, so we need to call
# aggregate() to get only common events
kv_cache_events
.
aggregate
()
aggregated_events
=
kv_cache_events
.
get_all_events
()
# Only the common event should remain after aggregation
# because it's the only event reported by all 3 workers
assert
len
(
aggregated_events
)
==
1
assert
aggregated_events
[
0
]
==
common_event
# Verify the common event properties
assert
aggregated_events
[
0
].
block_hashes
==
[
"hash_common"
]
assert
aggregated_events
[
0
].
parent_block_hash
==
"parent_common"
assert
aggregated_events
[
0
].
token_ids
==
[
1
,
2
,
3
]
vllm/distributed/kv_events.py
View file @
f4417f84
...
@@ -5,7 +5,7 @@ import queue
...
@@ -5,7 +5,7 @@ import queue
import
threading
import
threading
import
time
import
time
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
deque
from
collections
import
Counter
,
deque
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
dataclasses
import
asdict
from
dataclasses
import
asdict
from
itertools
import
count
from
itertools
import
count
...
@@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent):
...
@@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent):
lora_id
:
int
|
None
lora_id
:
int
|
None
medium
:
str
|
None
medium
:
str
|
None
def
__hash__
(
self
)
->
int
:
return
hash
(
(
tuple
(
self
.
block_hashes
),
self
.
parent_block_hash
,
tuple
(
self
.
token_ids
),
self
.
block_size
,
self
.
lora_id
,
self
.
medium
,
)
)
class
BlockRemoved
(
KVCacheEvent
):
class
BlockRemoved
(
KVCacheEvent
):
block_hashes
:
list
[
ExternalBlockHash
]
block_hashes
:
list
[
ExternalBlockHash
]
medium
:
str
|
None
medium
:
str
|
None
def
__hash__
(
self
)
->
int
:
return
hash
((
tuple
(
self
.
block_hashes
),
self
.
medium
))
class
AllBlocksCleared
(
KVCacheEvent
):
class
AllBlocksCleared
(
KVCacheEvent
):
pass
pass
...
@@ -68,6 +83,119 @@ class KVEventBatch(EventBatch):
...
@@ -68,6 +83,119 @@ class KVEventBatch(EventBatch):
events
:
list
[
BlockStored
|
BlockRemoved
|
AllBlocksCleared
]
events
:
list
[
BlockStored
|
BlockRemoved
|
AllBlocksCleared
]
class
KVEventAggregator
:
"""
Aggregates KV events across multiple workers.
Tracks how many times each event appears and returns only those
that were emitted by all workers.
"""
__slots__
=
(
"_event_counter"
,
"_num_workers"
)
def
__init__
(
self
,
num_workers
:
int
)
->
None
:
if
num_workers
<=
0
:
raise
ValueError
(
"num_workers must be greater than zero."
)
self
.
_event_counter
:
Counter
[
KVCacheEvent
]
=
Counter
()
self
.
_num_workers
:
int
=
num_workers
def
add_events
(
self
,
events
:
list
[
KVCacheEvent
])
->
None
:
"""
Add events from a worker batch.
:param events: List of KVCacheEvent objects.
"""
if
not
isinstance
(
events
,
list
):
raise
TypeError
(
"events must be a list of KVCacheEvent."
)
self
.
_event_counter
.
update
(
events
)
def
get_common_events
(
self
)
->
list
[
KVCacheEvent
]:
"""
Return events that appeared in all workers.
:return: List of events present in all workers.
"""
return
[
event
for
event
,
count
in
self
.
_event_counter
.
items
()
if
count
==
self
.
_num_workers
]
def
get_all_events
(
self
)
->
list
[
KVCacheEvent
]:
"""
Return all events for all workers.
:return: List of events for all workers.
"""
return
list
(
self
.
_event_counter
.
elements
())
def
clear_events
(
self
)
->
None
:
"""
Clear all tracked events.
"""
self
.
_event_counter
.
clear
()
def
increment_workers
(
self
,
count
:
int
=
1
)
->
None
:
"""
Increment the number of workers contributing events.
:param count: Number to increment the workers by.
"""
if
count
<=
0
:
raise
ValueError
(
"count must be positive."
)
self
.
_num_workers
+=
count
def
reset_workers
(
self
)
->
None
:
"""
Reset the number of workers to 1.
"""
self
.
_num_workers
=
1
def
get_number_of_workers
(
self
)
->
int
:
"""
Return the number of workers.
:return: int number of workers.
"""
return
self
.
_num_workers
def
__repr__
(
self
)
->
str
:
return
(
f
"<KVEventAggregator workers=
{
self
.
_num_workers
}
, "
f
"events=
{
len
(
self
.
_event_counter
)
}
>"
)
class
KVConnectorKVEvents
(
ABC
):
"""
Abstract base class for KV events.
Acts as a container for KV events from the connector.
"""
@
abstractmethod
def
add_events
(
self
,
events
:
list
[
KVCacheEvent
])
->
None
:
raise
NotImplementedError
@
abstractmethod
def
aggregate
(
self
)
->
"KVConnectorKVEvents"
:
raise
NotImplementedError
@
abstractmethod
def
increment_workers
(
self
,
count
:
int
=
1
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
get_all_events
(
self
)
->
list
[
KVCacheEvent
]:
raise
NotImplementedError
@
abstractmethod
def
get_number_of_workers
(
self
)
->
int
:
raise
NotImplementedError
@
abstractmethod
def
clear_events
(
self
)
->
None
:
raise
NotImplementedError
class
EventPublisher
(
ABC
):
class
EventPublisher
(
ABC
):
"""Lightweight publisher for EventBatch batches with data parallelism
"""Lightweight publisher for EventBatch batches with data parallelism
support.
support.
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
f4417f84
...
@@ -78,6 +78,7 @@ class KVOutputAggregator:
...
@@ -78,6 +78,7 @@ class KVOutputAggregator:
finished_sending
=
set
[
str
]()
finished_sending
=
set
[
str
]()
finished_recving
=
set
[
str
]()
finished_recving
=
set
[
str
]()
aggregated_kv_connector_stats
=
None
aggregated_kv_connector_stats
=
None
combined_kv_cache_events
=
None
invalid_block_ids
=
set
[
int
]()
invalid_block_ids
=
set
[
int
]()
for
model_runner_output
in
outputs
:
for
model_runner_output
in
outputs
:
assert
model_runner_output
is
not
None
assert
model_runner_output
is
not
None
...
@@ -119,6 +120,19 @@ class KVOutputAggregator:
...
@@ -119,6 +120,19 @@ class KVOutputAggregator:
aggregated_kv_connector_stats
.
aggregate
(
kv_connector_stats
)
aggregated_kv_connector_stats
.
aggregate
(
kv_connector_stats
)
)
)
# Combine kv_cache_events from all workers.
if
combined_kv_cache_events
is
None
:
# Use the first worker's kv_cache events as start event list.
combined_kv_cache_events
=
kv_output
.
kv_cache_events
elif
kv_cache_events
:
=
kv_output
.
kv_cache_events
:
assert
isinstance
(
combined_kv_cache_events
,
type
(
kv_cache_events
),
)
worker_kv_cache_events
=
kv_cache_events
.
get_all_events
()
combined_kv_cache_events
.
add_events
(
worker_kv_cache_events
)
combined_kv_cache_events
.
increment_workers
(
1
)
invalid_block_ids
|=
kv_output
.
invalid_block_ids
invalid_block_ids
|=
kv_output
.
invalid_block_ids
# select output of the worker specified by output_rank
# select output of the worker specified by output_rank
...
@@ -129,6 +143,7 @@ class KVOutputAggregator:
...
@@ -129,6 +143,7 @@ class KVOutputAggregator:
finished_sending
=
finished_sending
or
None
,
finished_sending
=
finished_sending
or
None
,
finished_recving
=
finished_recving
or
None
,
finished_recving
=
finished_recving
or
None
,
kv_connector_stats
=
aggregated_kv_connector_stats
or
None
,
kv_connector_stats
=
aggregated_kv_connector_stats
or
None
,
kv_cache_events
=
combined_kv_cache_events
or
None
,
invalid_block_ids
=
invalid_block_ids
,
invalid_block_ids
=
invalid_block_ids
,
expected_finished_count
=
self
.
_expected_finished_count
,
expected_finished_count
=
self
.
_expected_finished_count
,
)
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
f4417f84
...
@@ -49,7 +49,7 @@ from vllm.v1.outputs import KVConnectorOutput
...
@@ -49,7 +49,7 @@ from vllm.v1.outputs import KVConnectorOutput
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_events
import
KVCacheEvent
,
KVConnectorKVEvents
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorPromMetrics
,
KVConnectorPromMetrics
,
KVConnectorStats
,
KVConnectorStats
,
...
@@ -379,6 +379,14 @@ class KVConnectorBase_V1(ABC):
...
@@ -379,6 +379,14 @@ class KVConnectorBase_V1(ABC):
"""
"""
return
None
return
None
def
get_kv_connector_kv_cache_events
(
self
)
->
Optional
[
"KVConnectorKVEvents"
]:
"""
Get the KV connector kv cache events collected during the last interval.
This function should be called by the model runner every time after the
model execution and before cleanup.
"""
return
None
def
get_handshake_metadata
(
self
)
->
KVConnectorHandshakeMetadata
|
None
:
def
get_handshake_metadata
(
self
)
->
KVConnectorHandshakeMetadata
|
None
:
"""
"""
Get the KVConnector handshake metadata for this connector.
Get the KVConnector handshake metadata for this connector.
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
View file @
f4417f84
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
import
torch
import
torch
from
lmcache.integration.vllm.vllm_v1_adapter
import
(
LMCacheConnectorV1Impl
as
LMCacheConnectorLatestImpl
,
)
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
(
BlockStored
,
KVCacheEvent
,
KVConnectorKVEvents
,
KVEventAggregator
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorMetadata
,
...
@@ -16,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
...
@@ -16,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
)
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
...
@@ -26,6 +31,44 @@ if TYPE_CHECKING:
...
@@ -26,6 +31,44 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
LMCacheKVEvents
(
KVConnectorKVEvents
):
"""
Concrete implementation of KVConnectorKVEvents using KVEventAggregator.
"""
def
__init__
(
self
,
num_workers
:
int
)
->
None
:
self
.
_aggregator
=
KVEventAggregator
(
num_workers
)
def
add_events
(
self
,
events
:
list
[
KVCacheEvent
])
->
None
:
self
.
_aggregator
.
add_events
(
events
)
def
aggregate
(
self
)
->
"LMCacheKVEvents"
:
"""
Aggregate KV events and retain only common events.
"""
common_events
=
self
.
_aggregator
.
get_common_events
()
self
.
_aggregator
.
clear_events
()
self
.
_aggregator
.
add_events
(
common_events
)
self
.
_aggregator
.
reset_workers
()
return
self
def
increment_workers
(
self
,
count
:
int
=
1
)
->
None
:
self
.
_aggregator
.
increment_workers
(
count
)
def
get_all_events
(
self
)
->
list
[
KVCacheEvent
]:
return
self
.
_aggregator
.
get_all_events
()
def
get_number_of_workers
(
self
)
->
int
:
return
self
.
_aggregator
.
get_number_of_workers
()
def
clear_events
(
self
)
->
None
:
self
.
_aggregator
.
clear_events
()
self
.
_aggregator
.
reset_workers
()
def
__repr__
(
self
)
->
str
:
return
f
"<LMCacheKVEvents events=
{
self
.
get_all_events
()
}
>"
class
LMCacheConnectorV1
(
KVConnectorBase_V1
):
class
LMCacheConnectorV1
(
KVConnectorBase_V1
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -50,10 +93,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
...
@@ -50,10 +93,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
cls
=
_adapter
.
LMCacheConnectorV1Impl
cls
=
_adapter
.
LMCacheConnectorV1Impl
else
:
else
:
logger
.
info
(
"Initializing latest dev LMCache connector"
)
logger
.
info
(
"Initializing latest dev LMCache connector"
)
# lazy import
from
lmcache.integration.vllm.vllm_v1_adapter
import
(
LMCacheConnectorV1Impl
as
LMCacheConnectorLatestImpl
,
)
cls
=
LMCacheConnectorLatestImpl
cls
=
LMCacheConnectorLatestImpl
self
.
_lmcache_engine
=
cls
(
vllm_config
,
role
,
self
)
self
.
_lmcache_engine
=
cls
(
vllm_config
,
role
,
self
)
self
.
_kv_cache_events
:
LMCacheKVEvents
|
None
=
None
# ==============================
# ==============================
# Worker-side methods
# Worker-side methods
# ==============================
# ==============================
...
@@ -151,6 +201,31 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
...
@@ -151,6 +201,31 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
# Fallback for older versions that don't support this method
# Fallback for older versions that don't support this method
return
set
()
return
set
()
def
get_kv_connector_kv_cache_events
(
self
)
->
LMCacheKVEvents
|
None
:
"""
Get the KV connector kv cache events collected during the last interval.
"""
events
=
self
.
_lmcache_engine
.
get_kv_events
()
# type: ignore [attr-defined]
if
not
events
:
return
None
blocks
:
list
[
BlockStored
]
=
[
BlockStored
(
block_hashes
=
e
.
block_hashes
,
parent_block_hash
=
e
.
parent_block_hash
,
token_ids
=
e
.
token_ids
,
lora_id
=
e
.
lora_id
,
block_size
=
e
.
block_size
,
medium
=
e
.
medium
,
)
for
e
in
events
]
lmcache_kv_events
=
LMCacheKVEvents
(
num_workers
=
1
)
lmcache_kv_events
.
add_events
(
blocks
)
return
lmcache_kv_events
# ==============================
# ==============================
# Scheduler-side methods
# Scheduler-side methods
# ==============================
# ==============================
...
@@ -198,6 +273,28 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
...
@@ -198,6 +273,28 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
"""
"""
return
self
.
_lmcache_engine
.
build_connector_meta
(
scheduler_output
)
return
self
.
_lmcache_engine
.
build_connector_meta
(
scheduler_output
)
def
update_connector_output
(
self
,
connector_output
:
KVConnectorOutput
):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
# Get the KV events
kv_cache_events
=
connector_output
.
kv_cache_events
if
not
kv_cache_events
or
not
isinstance
(
kv_cache_events
,
LMCacheKVEvents
):
return
if
self
.
_kv_cache_events
is
None
:
self
.
_kv_cache_events
=
kv_cache_events
else
:
self
.
_kv_cache_events
.
add_events
(
kv_cache_events
.
get_all_events
())
self
.
_kv_cache_events
.
increment_workers
(
kv_cache_events
.
get_number_of_workers
()
)
return
def
request_finished
(
def
request_finished
(
self
,
self
,
request
:
"Request"
,
request
:
"Request"
,
...
@@ -214,3 +311,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
...
@@ -214,3 +311,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
returned by the engine.
returned by the engine.
"""
"""
return
self
.
_lmcache_engine
.
request_finished
(
request
,
block_ids
)
return
self
.
_lmcache_engine
.
request_finished
(
request
,
block_ids
)
def
take_events
(
self
)
->
Iterable
[
"KVCacheEvent"
]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
if
self
.
_kv_cache_events
is
not
None
:
self
.
_kv_cache_events
.
aggregate
()
kv_cache_events
=
self
.
_kv_cache_events
.
get_all_events
()
yield
from
kv_cache_events
self
.
_kv_cache_events
.
clear_events
()
self
.
_kv_cache_events
=
None
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
f4417f84
...
@@ -259,6 +259,12 @@ class MultiConnector(KVConnectorBase_V1):
...
@@ -259,6 +259,12 @@ class MultiConnector(KVConnectorBase_V1):
agg_block_ids
|=
c
.
get_block_ids_with_load_errors
()
agg_block_ids
|=
c
.
get_block_ids_with_load_errors
()
return
agg_block_ids
return
agg_block_ids
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' method
# for the MultiConnector. It should be able to get events from multiple
# connectors, handling the case where only a subset of the requested connectors
# implements the 'get_kv_connector_kv_cache_events'
# Follow on PR from https://github.com/vllm-project/vllm/pull/28309#pullrequestreview-3566351082
# ==============================
# ==============================
# Scheduler-side methods
# Scheduler-side methods
# ==============================
# ==============================
...
...
vllm/v1/outputs.py
View file @
f4417f84
...
@@ -12,9 +12,11 @@ from vllm.compilation.cuda_graph import CUDAGraphStat
...
@@ -12,9 +12,11 @@ from vllm.compilation.cuda_graph import CUDAGraphStat
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.distributed.kv_events
import
KVConnectorKVEvents
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
KVConnectorStats
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
KVConnectorStats
else
:
else
:
KVConnectorStats
=
object
KVConnectorStats
=
object
KVConnectorKVEvents
=
object
class
LogprobsLists
(
NamedTuple
):
class
LogprobsLists
(
NamedTuple
):
...
@@ -108,6 +110,7 @@ class KVConnectorOutput:
...
@@ -108,6 +110,7 @@ class KVConnectorOutput:
finished_sending
:
set
[
str
]
|
None
=
None
finished_sending
:
set
[
str
]
|
None
=
None
finished_recving
:
set
[
str
]
|
None
=
None
finished_recving
:
set
[
str
]
|
None
=
None
kv_connector_stats
:
KVConnectorStats
|
None
=
None
kv_connector_stats
:
KVConnectorStats
|
None
=
None
kv_cache_events
:
KVConnectorKVEvents
|
None
=
None
# IDs of externally computed KV blocks that failed to load.
# IDs of externally computed KV blocks that failed to load.
# Requests referencing these blocks should be rescheduled to recompute them
# Requests referencing these blocks should be rescheduled to recompute them
invalid_block_ids
:
set
[
int
]
=
field
(
default_factory
=
set
)
invalid_block_ids
:
set
[
int
]
=
field
(
default_factory
=
set
)
...
@@ -123,6 +126,7 @@ class KVConnectorOutput:
...
@@ -123,6 +126,7 @@ class KVConnectorOutput:
not
self
.
finished_sending
not
self
.
finished_sending
and
not
self
.
finished_recving
and
not
self
.
finished_recving
and
not
self
.
kv_connector_stats
and
not
self
.
kv_connector_stats
and
not
self
.
kv_cache_events
and
not
self
.
invalid_block_ids
and
not
self
.
invalid_block_ids
)
)
...
...
vllm/v1/worker/kv_connector_model_runner_mixin.py
View file @
f4417f84
...
@@ -22,7 +22,6 @@ from vllm.distributed.kv_transfer import (
...
@@ -22,7 +22,6 @@ from vllm.distributed.kv_transfer import (
has_kv_transfer_group
,
has_kv_transfer_group
,
)
)
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
KVConnectorStats
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
KVCacheConfig
...
@@ -138,16 +137,10 @@ class KVConnectorModelRunnerMixin:
...
@@ -138,16 +137,10 @@ class KVConnectorModelRunnerMixin:
)
)
output
.
invalid_block_ids
=
kv_connector
.
get_block_ids_with_load_errors
()
output
.
invalid_block_ids
=
kv_connector
.
get_block_ids_with_load_errors
()
output
.
kv_connector_stats
=
(
output
.
kv_connector_stats
=
kv_connector
.
get_kv_connector_stats
()
KVConnectorModelRunnerMixin
.
get_kv_connector_stats
()
output
.
kv_cache_events
=
kv_connector
.
get_kv_connector_kv_cache_events
()
)
kv_connector
.
clear_connector_metadata
()
@
staticmethod
kv_connector
.
clear_connector_metadata
()
def
get_kv_connector_stats
()
->
KVConnectorStats
|
None
:
if
has_kv_transfer_group
():
return
get_kv_transfer_group
().
get_kv_connector_stats
()
return
None
@
staticmethod
@
staticmethod
def
use_uniform_kv_cache
(
def
use_uniform_kv_cache
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment