Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
14b4326b
Unverified
Commit
14b4326b
authored
Sep 01, 2025
by
Or Ozeri
Committed by
GitHub
Sep 01, 2025
Browse files
v1: Support KV events from connectors (#19737)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
752d2e1c
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
44 additions
and
3 deletions
+44
-3
examples/online_serving/kv_events_subscriber.py
examples/online_serving/kv_events_subscriber.py
+2
-0
vllm/distributed/kv_events.py
vllm/distributed/kv_events.py
+5
-0
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+13
-0
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+6
-0
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+6
-3
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+12
-0
No files found.
examples/online_serving/kv_events_subscriber.py
View file @
14b4326b
...
...
@@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent):
token_ids
:
list
[
int
]
block_size
:
int
lora_id
:
Optional
[
int
]
medium
:
Optional
[
str
]
class
BlockRemoved
(
KVCacheEvent
):
block_hashes
:
list
[
int
]
medium
:
Optional
[
str
]
class
AllBlocksCleared
(
KVCacheEvent
):
...
...
vllm/distributed/kv_events.py
View file @
14b4326b
...
...
@@ -40,16 +40,21 @@ class KVCacheEvent(
"""Base class for all KV cache-related events"""
MEDIUM_GPU
=
"GPU"
class
BlockStored
(
KVCacheEvent
):
block_hashes
:
list
[
int
]
parent_block_hash
:
Optional
[
int
]
token_ids
:
list
[
int
]
block_size
:
int
lora_id
:
Optional
[
int
]
medium
:
Optional
[
str
]
class
BlockRemoved
(
KVCacheEvent
):
block_hashes
:
list
[
int
]
medium
:
Optional
[
str
]
class
AllBlocksCleared
(
KVCacheEvent
):
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
14b4326b
...
...
@@ -19,6 +19,8 @@ The class provides the following primitives:
Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer
params.
take_events() - returns new KV events that were collected
by the connector since the last call.
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
...
...
@@ -34,6 +36,7 @@ The class provides the following primitives:
import
enum
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Literal
,
Optional
import
torch
...
...
@@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.request
import
Request
...
...
@@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC):
"""
return
False
,
None
def
take_events
(
self
)
->
Iterable
[
"KVCacheEvent"
]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
return
()
@
classmethod
def
get_required_kvcache_layout
(
cls
,
vllm_config
:
"VllmConfig"
)
->
Optional
[
str
]:
...
...
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
14b4326b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
from
vllm.config
import
KVTransferConfig
,
VllmConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
...
...
@@ -208,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1):
return
async_saves
>
0
,
kv_txfer_params
def
take_events
(
self
)
->
Iterable
[
KVCacheEvent
]:
for
c
in
self
.
_connectors
:
yield
from
c
.
take_events
()
@
classmethod
def
get_required_kvcache_layout
(
cls
,
vllm_config
:
"VllmConfig"
)
->
Optional
[
str
]:
...
...
vllm/v1/core/block_pool.py
View file @
14b4326b
...
...
@@ -4,8 +4,9 @@ from collections import defaultdict
from
collections.abc
import
Iterable
from
typing
import
Optional
from
vllm.distributed.kv_events
import
(
AllBlocksCleared
,
BlockRemoved
,
BlockStored
,
KVCacheEvent
)
from
vllm.distributed.kv_events
import
(
MEDIUM_GPU
,
AllBlocksCleared
,
BlockRemoved
,
BlockStored
,
KVCacheEvent
)
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
FreeKVCacheBlockQueue
,
KVCacheBlock
)
...
...
@@ -156,6 +157,7 @@ class BlockPool:
block_size
=
block_size
,
lora_id
=
request
.
lora_request
.
id
if
request
.
lora_request
else
None
,
medium
=
MEDIUM_GPU
,
))
def
get_new_blocks
(
self
,
num_blocks
:
int
)
->
list
[
KVCacheBlock
]:
...
...
@@ -218,7 +220,8 @@ class BlockPool:
# we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group.
self
.
kv_event_queue
.
append
(
BlockRemoved
(
block_hashes
=
[
block_hash
.
get_hash_value
()]))
BlockRemoved
(
block_hashes
=
[
block_hash
.
get_hash_value
()],
medium
=
MEDIUM_GPU
))
return
True
def
touch
(
self
,
blocks
:
tuple
[
list
[
KVCacheBlock
],
...])
->
None
:
...
...
vllm/v1/core/sched/scheduler.py
View file @
14b4326b
...
...
@@ -589,7 +589,19 @@ class Scheduler(SchedulerInterface):
meta
=
self
.
connector
.
build_connector_meta
(
scheduler_output
)
scheduler_output
.
kv_connector_metadata
=
meta
# collect KV cache events from KV cache manager
events
=
self
.
kv_cache_manager
.
take_events
()
# collect KV cache events from connector
if
self
.
connector
is
not
None
:
connector_events
=
self
.
connector
.
take_events
()
if
connector_events
:
if
events
is
None
:
events
=
list
(
connector_events
)
else
:
events
.
extend
(
connector_events
)
# publish collected KV cache events
if
events
:
batch
=
KVEventBatch
(
ts
=
time
.
time
(),
events
=
events
)
self
.
kv_event_publisher
.
publish
(
batch
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment