Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
525f2eeb
Unverified
Commit
525f2eeb
authored
Mar 18, 2026
by
Or Ozeri
Committed by
GitHub
Mar 18, 2026
Browse files
[kv_offload+HMA][6/N]: Split offloading_connector.py (#37405)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
918b7890
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
733 additions
and
658 deletions
+733
-658
tests/v1/kv_connector/unit/test_offloading_connector.py
tests/v1/kv_connector/unit/test_offloading_connector.py
+6
-2
vllm/distributed/kv_transfer/kv_connector/v1/offloading/__init__.py
...ibuted/kv_transfer/kv_connector/v1/offloading/__init__.py
+0
-0
vllm/distributed/kv_transfer/kv_connector/v1/offloading/common.py
...tributed/kv_transfer/kv_connector/v1/offloading/common.py
+15
-0
vllm/distributed/kv_transfer/kv_connector/v1/offloading/metrics.py
...ributed/kv_transfer/kv_connector/v1/offloading/metrics.py
+165
-0
vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
...buted/kv_transfer/kv_connector/v1/offloading/scheduler.py
+347
-0
vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
...tributed/kv_transfer/kv_connector/v1/offloading/worker.py
+185
-0
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
...buted/kv_transfer/kv_connector/v1/offloading_connector.py
+15
-656
No files found.
tests/v1/kv_connector/unit/test_offloading_connector.py
View file @
525f2eeb
...
...
@@ -13,11 +13,15 @@ from vllm import SamplingParams
from
vllm.config
import
KVTransferConfig
,
VllmConfig
from
vllm.distributed.kv_events
import
BlockRemoved
,
BlockStored
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorRole
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector
import
(
OffloadingConnector
,
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading.common
import
(
OffloadingConnectorMetadata
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics
import
(
OffloadingConnectorStats
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector
import
(
OffloadingConnector
,
)
from
vllm.forward_context
import
ForwardContext
from
vllm.utils.hashing
import
sha256
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
...
...
vllm/distributed/kv_transfer/kv_connector/v1/offloading/__init__.py
0 → 100644
View file @
525f2eeb
vllm/distributed/kv_transfer/kv_connector/v1/offloading/common.py
0 → 100644
View file @
525f2eeb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
KVConnectorMetadata
from
vllm.v1.kv_offload.worker.worker
import
TransferSpec
# Request identifiers are plain strings.
ReqId = str


@dataclass
class OffloadingConnectorMetadata(KVConnectorMetadata):
    """Transfer instructions exchanged between the offloading connector sides.

    Instances are built by the scheduler-side connector
    (``build_connector_meta``) and read by the worker-side connector.
    """

    # request ID -> transfer spec describing the blocks to load for it
    reqs_to_load: dict[ReqId, TransferSpec]
    # request ID -> transfer spec describing the blocks to store for it
    reqs_to_store: dict[ReqId, TransferSpec]
    # IDs of requests whose pending store jobs should be waited upon
    # (the scheduler fills this with the preempted request IDs), or None
    reqs_to_flush: set[str] | None = None
vllm/distributed/kv_transfer/kv_connector/v1/offloading/metrics.py
0 → 100644
View file @
525f2eeb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Any
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorPromMetrics
,
KVConnectorStats
,
PromMetric
,
PromMetricT
,
)
from
vllm.logger
import
init_logger
from
vllm.v1.kv_offload.worker.worker
import
TransferType
logger
=
init_logger
(
__name__
)
@dataclass
class OffloadingOperationMetrics:
    """Measurements recorded for a single offloading transfer operation."""

    # size of the transfer (recorded from a byte count by record_transfer)
    op_size: int
    # duration of the transfer
    # NOTE(review): the time unit is not visible here - confirm with the
    # producer of the transfer results
    op_time: float
@dataclass
class OffloadingConnectorStats(KVConnectorStats):
    """Container of offloading transfer measurements, grouped by transfer type.

    ``self.data`` maps a transfer-type key (``"<src>_to_<dst>"``) to the list
    of operations recorded for that type.
    """

    def __post_init__(self):
        if not self.data:
            # Empty container init, no data is passed in.
            self.reset()

    def reset(self):
        """Drop all recorded operations."""
        self.data: dict[str, list[OffloadingOperationMetrics]] = {}

    def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
        """Merge the operations recorded in ``other`` into this container.

        Args:
            other: the stats object whose per-transfer-type lists are merged
                into ``self.data``. It is never modified.

        Returns:
            ``self``, to allow chaining.
        """
        if not other.is_empty():
            for k, v in other.data.items():
                if k not in self.data:
                    # Copy instead of aliasing: a later extend() on our
                    # accumulator must not grow the list owned by `other`.
                    self.data[k] = list(v)
                else:
                    accumulator = self.data[k]
                    assert isinstance(accumulator, list)
                    accumulator.extend(v)
        return self

    def reduce(self) -> dict[str, int | float]:
        """
        Reduce the observations collected during a time interval to one or
        more representative values (eg avg/median/sum of the series).
        This is meant to be called by the logger to produce a summary of the
        stats for the last time interval.

        Each recorded operation is expected to be a dict carrying the
        "op_size" and "op_time" keys (this is asserted).

        Returns:
            A dict with, for every transfer type, the entries
            ``<transfer_type>_total_bytes`` and ``<transfer_type>_total_time``.
        """
        return_dict: dict[str, int | float] = {}
        for transfer_type, ops_list in self.data.items():
            assert isinstance(ops_list, list)
            total_bytes = 0
            total_time = 0.0
            for op in ops_list:
                assert isinstance(op, dict)
                total_bytes += op["op_size"]
                total_time += op["op_time"]
            return_dict[f"{transfer_type}_total_bytes"] = total_bytes
            return_dict[f"{transfer_type}_total_time"] = total_time
        return return_dict

    def is_empty(self) -> bool:
        """Return True when no operation has been recorded."""
        return not self.data

    def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType):
        """Record one finished transfer.

        Args:
            num_bytes: size of the transfer.
            time: duration of the transfer.
            transfer_type: a (source, destination) pair; the operation is
                filed under the key ``"<source>_to_<destination>"``.
        """
        src, dst = transfer_type
        transfer_type_key = src + "_to_" + dst
        op = OffloadingOperationMetrics(num_bytes, time)
        self.data.setdefault(transfer_type_key, []).append(op)
class OffloadPromMetrics(KVConnectorPromMetrics):
    """Prometheus metrics for KV offloading, labelled by transfer type."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)

        # Caches of label-bound metrics, keyed by (engine_idx, transfer_type)
        self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {}
        self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {}
        self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {}

        # Histogram bucket upper bounds, in bytes
        size_buckets = (1e6, 5e6, 10e6, 20e6, 40e6, 60e6, 80e6, 100e6, 150e6, 200e6)

        # Unbound metrics; every one gets an extra "transfer_type" label
        self._counter_kv_bytes = self._counter_cls(
            name="vllm:kv_offload_total_bytes",
            documentation="Number of bytes offloaded by KV connector",
            labelnames=labelnames + ["transfer_type"],
        )
        self._counter_kv_transfer_time = self._counter_cls(
            name="vllm:kv_offload_total_time",
            documentation="Total time measured by all KV offloading operations",
            labelnames=labelnames + ["transfer_type"],
        )
        self._histogram_transfer_size = self._histogram_cls(
            name="vllm:kv_offload_size",
            documentation="Histogram of KV offload transfer size, in bytes.",
            buckets=list(size_buckets),
            labelnames=labelnames + ["transfer_type"],
        )

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        """
        Feed recorded transfer operations into the Prometheus metrics.

        Args:
            transfer_stats_data: maps a transfer type string
                (e.g., "cpu_to_gpu", "gpu_to_cpu") to a list of serialized
                operations, each a dict with "op_size" and "op_time" keys.
            engine_idx: index of the engine whose label values are used.
        """
        for transfer_type, ops in transfer_stats_data.items():
            key = (engine_idx, transfer_type)

            # Bind the labels once per (engine, transfer type) and cache them
            if key not in self.histogram_transfer_size:
                for cache, metric in (
                    (self.histogram_transfer_size, self._histogram_transfer_size),
                    (self.counter_kv_bytes, self._counter_kv_bytes),
                    (self.counter_kv_transfer_time, self._counter_kv_transfer_time),
                ):
                    label_values = self.per_engine_labelvalues[engine_idx] + [
                        transfer_type
                    ]
                    cache[key] = metric.labels(*label_values)

            assert isinstance(ops, list)
            for op in ops:
                # every op is a serialized OffloadingOperationMetrics
                assert isinstance(op, dict)
                # size distribution first, then the running totals
                self.histogram_transfer_size[key].observe(op["op_size"])
                self.counter_kv_bytes[key].inc(op["op_size"])
                self.counter_kv_transfer_time[key].inc(op["op_time"])
vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
0 → 100644
View file @
525f2eeb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections
import
defaultdict
from
collections.abc
import
Iterable
from
itertools
import
islice
from
typing
import
Any
from
vllm.distributed.kv_events
import
BlockRemoved
,
BlockStored
,
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.utils
import
yield_req_data
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
KVConnectorMetadata
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading.common
import
(
OffloadingConnectorMetadata
,
ReqId
,
)
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.kv_cache_utils
import
BlockHash
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_offload.abstract
import
OffloadingManager
from
vllm.v1.kv_offload.mediums
import
GPULoadStoreSpec
from
vllm.v1.kv_offload.spec
import
OffloadingSpec
from
vllm.v1.kv_offload.worker.worker
import
TransferSpec
from
vllm.v1.outputs
import
KVConnectorOutput
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
class OffloadingConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, spec: OffloadingSpec):
        # only a single GPU block size is supported
        assert len(spec.gpu_block_size) == 1
        self.gpu_block_size = spec.gpu_block_size[0]
        # an offloaded block spans block_size_factor GPU blocks
        self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor
        self.block_size_factor = spec.block_size_factor
        self.manager: OffloadingManager = spec.get_manager()

        self._requests: dict[ReqId, Request] = {}
        # list of GPU block IDs per request
        self._request_block_ids: dict[ReqId, list[int]] = {}
        # requests to load for the current scheduler step
        self._reqs_to_load: dict[ReqId, TransferSpec] = {}
        # request blocks are stored in order
        # index of next block (of size offloaded_block_size) to offload
        self._next_stored_block_idx: dict[ReqId, int] = {}
        # if GPU prefix caching is enabled,
        # track loaded blocks to avoid redundant loads
        self._blocks_being_loaded: set[BlockHash] | None = (
            set() if spec.vllm_config.cache_config.enable_prefix_caching else None
        )

        # request ID -> set(block hashes being stored/load)
        self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
        self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)

    def _get_block_hashes(
        self,
        req: Request,
        start_idx: int = 0,
        end_idx: int | None = None,
    ) -> Iterable[BlockHash]:
        """Iterate the hashes of a request's offloaded blocks.

        An offloaded block covers ``block_size_factor`` consecutive entries of
        ``req.block_hashes``; it is represented by the last of those entries.

        Args:
            req: the request whose ``block_hashes`` are read.
            start_idx: index of the first offloaded block to include.
            end_idx: index one past the last offloaded block to include, or
                None for no upper bound. ``0`` yields nothing.

        Returns:
            A one-shot iterator; call again to iterate a second time.
        """
        return islice(
            req.block_hashes,
            self.block_size_factor * start_idx + self.block_size_factor - 1,
            # compare against None explicitly: end_idx == 0 is a valid
            # (empty) bound and must not be mistaken for "unbounded"
            self.block_size_factor * end_idx if end_idx is not None else None,
            self.block_size_factor,
        )

    def get_num_new_matched_tokens(
        self, request: Request, num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded beyond the
        num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that can be loaded beyond what is
                  already computed.
                  If None, it means that the connector needs more time to
                  determine the number of matched tokens, and the scheduler
                  should query for this request again later.
                - `True` if tokens will be loaded asynchronously
                  (between scheduler steps).
        """
        num_blocks = request.num_tokens // self.offloaded_block_size

        assert len(request.block_hashes) // self.block_size_factor == num_blocks
        block_hashes = self._get_block_hashes(request)

        self.manager.touch(block_hashes)

        full_block_tokens = self.offloaded_block_size * num_blocks
        if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
            # we can load less than a block, skip
            return 0, False

        start_block_idx = num_computed_tokens // self.offloaded_block_size
        hits = self.manager.lookup(
            self._get_block_hashes(request, start_idx=start_block_idx)
        )
        if hits is None:
            # indicates a lookup that should be tried later
            return None, False
        if hits == 0:
            return 0, False

        num_hit_tokens = (
            self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
        )
        logger.debug(
            "Request %s hit %s offloaded tokens after %s GPU hit tokens",
            request.request_id,
            num_hit_tokens,
            num_computed_tokens,
        )
        if num_hit_tokens < self.offloaded_block_size:
            return 0, False

        if self._blocks_being_loaded:
            block_hashes = self._get_block_hashes(
                request, start_idx=start_block_idx, end_idx=start_block_idx + hits
            )

            if any(
                block_hash in self._blocks_being_loaded for block_hash in block_hashes
            ):
                # hit blocks are being loaded, delay request
                logger.debug(
                    "Delaying request %s since some of its blocks are already"
                    " being loaded",
                    request.request_id,
                )
                return None, False

        return num_hit_tokens, True

    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        """Register a request after allocation and schedule its load, if any.

        Args:
            request: the request that was allocated blocks.
            blocks: the blocks allocated for the request.
            num_external_tokens: number of tokens to load from the offloaded
                medium; when 0 the request is only registered.
        """
        self._requests[request.request_id] = request
        # the block ids are updated in _get_reqs_to_store
        self._request_block_ids[request.request_id] = []

        if num_external_tokens == 0:
            return

        block_groups = blocks.get_block_ids()
        block_ids = block_groups[0]

        # blocks carrying a hash are counted as already computed on the GPU
        num_computed_gpu_blocks = sum(
            block.block_hash is not None for block in blocks.blocks[0]
        )
        num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
        full_block_tokens = num_computed_tokens + num_external_tokens
        assert full_block_tokens % self.offloaded_block_size == 0

        num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
        assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size

        start_block_idx = num_computed_tokens // self.offloaded_block_size
        num_blocks = full_block_tokens // self.offloaded_block_size

        assert len(request.block_hashes) // self.block_size_factor >= num_blocks
        block_hashes = self._get_block_hashes(
            request, start_idx=start_block_idx, end_idx=num_blocks
        )

        src_spec = self.manager.prepare_load(block_hashes)
        dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])

        # _get_block_hashes returns a one-shot iterator, so it is re-created
        # here for the bookkeeping below
        block_hashes = self._get_block_hashes(
            request, start_idx=start_block_idx, end_idx=num_blocks
        )

        self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
        req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
        req_blocks_being_loaded.update(block_hashes)
        self._next_stored_block_idx[request.request_id] = num_blocks

        if self._blocks_being_loaded is not None:
            self._blocks_being_loaded.update(req_blocks_being_loaded)

    def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
        """Build the store transfer specs for the current scheduler step.

        Returns:
            A dict mapping request ID to a (source GPU spec, destination
            store spec) pair, for every request with new blocks to offload.
        """
        reqs_to_store: dict[ReqId, TransferSpec] = {}
        # iterate over both new and cached requests
        for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
            if preempted:
                self._request_block_ids[req_id] = []

            if new_block_id_groups:
                new_block_ids = new_block_id_groups[0]
                self._request_block_ids[req_id] += new_block_ids

            block_ids = self._request_block_ids[req_id]

            req = self._requests[req_id]
            new_tokens = scheduler_output.num_scheduled_tokens[req_id]
            expected_tokens = req.num_computed_tokens + new_tokens
            # with async scheduling, some tokens may be missing
            total_tokens = min(expected_tokens, req.num_tokens)
            num_blocks = total_tokens // self.offloaded_block_size
            start_block_idx = self._next_stored_block_idx.get(req_id, 0)
            num_new_blocks = num_blocks - start_block_idx

            if num_new_blocks <= 0:
                continue

            num_gpu_blocks = num_blocks * self.block_size_factor
            assert len(req.block_hashes) >= num_gpu_blocks

            new_block_hashes = self._get_block_hashes(
                req, start_idx=start_block_idx, end_idx=num_blocks
            )
            store_output = self.manager.prepare_store(new_block_hashes)
            if store_output is None:
                logger.warning(
                    "Request %s: cannot store %s blocks", req_id, num_new_blocks
                )
                continue

            self._next_stored_block_idx[req_id] = num_blocks

            if not store_output.block_hashes_to_store:
                continue
            block_hashes_to_store = set(store_output.block_hashes_to_store)

            block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
            self.manager.touch(block_hashes)

            # the earlier iterator was consumed by prepare_store
            new_block_hashes = self._get_block_hashes(
                req, start_idx=start_block_idx, end_idx=num_blocks
            )
            dst_spec = store_output.store_spec
            src_block_ids: list[int] = []
            for idx, blk_hash in enumerate(new_block_hashes):
                if blk_hash not in block_hashes_to_store:
                    continue
                offloaded_block_idx = start_block_idx + idx
                gpu_block_idx = offloaded_block_idx * self.block_size_factor
                # collect all GPU blocks making up this offloaded block
                for i in range(self.block_size_factor):
                    src_block_ids.append(block_ids[gpu_block_idx + i])
            src_spec = GPULoadStoreSpec(src_block_ids)

            reqs_to_store[req_id] = (src_spec, dst_spec)
            self._reqs_being_stored[req_id] |= block_hashes_to_store

            logger.debug(
                "Request %s offloading %s blocks starting from block #%d",
                req_id,
                len(block_hashes_to_store),
                start_block_idx,
            )

        return reqs_to_store

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """Build the metadata for this step and reset the pending loads.

        Stores of preempted requests are marked complete here.
        """
        meta = OffloadingConnectorMetadata(
            reqs_to_load=self._reqs_to_load,
            reqs_to_store=self._get_reqs_to_store(scheduler_output),
            reqs_to_flush=scheduler_output.preempted_req_ids,
        )
        self._reqs_to_load = {}

        # NOTE (orozery): we should move this logic to update_connector_output
        # once KVConnectorOutput allows us to report completed transfers
        for req_id in scheduler_output.preempted_req_ids or ():
            block_hashes = self._reqs_being_stored.get(req_id)
            if block_hashes:
                self.manager.complete_store(block_hashes)
                block_hashes.clear()

        return meta

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        for req_id in connector_output.finished_sending or []:
            block_hashes = self._reqs_being_stored.pop(req_id, None)
            if block_hashes:
                self.manager.complete_store(block_hashes)

        for req_id in connector_output.finished_recving or []:
            block_hashes = self._reqs_being_loaded.pop(req_id, None)
            if block_hashes:
                if self._blocks_being_loaded:
                    self._blocks_being_loaded.difference_update(block_hashes)
                self.manager.complete_load(block_hashes)

    def request_finished(
        self,
        request: Request,
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        req_id = request.request_id
        self._requests.pop(req_id, None)
        self._request_block_ids.pop(req_id, None)

        # TODO(orozery): possibly kickoff offload for last block
        # which may have been deferred due to async scheduling
        self._next_stored_block_idx.pop(req_id, None)

        request_being_stored = req_id in self._reqs_being_stored
        return request_being_stored, None

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Take the KV cache events from the connector.

        Yields:
            A BlockRemoved or BlockStored event for every event taken from
            the offloading manager.
        """
        for event in self.manager.take_events():
            if event.removed:
                yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
            else:
                yield BlockStored(
                    block_hashes=event.block_hashes,
                    parent_block_hash=None,
                    token_ids=[],
                    lora_id=None,
                    block_size=event.block_size,
                    medium=event.medium,
                    lora_name=None,
                )
vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
0 → 100644
View file @
525f2eeb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections
import
defaultdict
import
torch
from
vllm.config
import
get_layers_from_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading.common
import
(
OffloadingConnectorMetadata
,
ReqId
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics
import
(
OffloadingConnectorStats
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.kv_offload.spec
import
OffloadingSpec
from
vllm.v1.kv_offload.worker.worker
import
(
OffloadingWorker
,
TransferSpec
,
)
logger
=
init_logger
(
__name__
)
class OffloadingConnectorWorker:
    """Implementation of Worker side methods"""

    def __init__(self, spec: OffloadingSpec):
        self.spec = spec
        self.worker = OffloadingWorker()

        # monotonically increasing source of job IDs
        self._job_counter = 0

        self.kv_connector_stats = OffloadingConnectorStats()
        # job_id -> (req_id, store)
        self._jobs: dict[int, tuple[ReqId, bool]] = {}
        # req_id -> active load job ID
        self._load_job: dict[ReqId, int] = {}
        # req_id -> set(active store job IDs)
        self._store_jobs = defaultdict[ReqId, set[int]](set)
        # list of store jobs pending submission (job_id, transfer_spec)
        self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []

        self._finished_reqs_waiting_for_store: set[ReqId] = set()

    def _generate_job_id(self) -> int:
        """Return the next unused job ID."""
        job_id = self._job_counter
        self._job_counter = job_id + 1
        return job_id

    def _submit_unsubmitted_store_jobs(self) -> None:
        """Submit every deferred store job to the worker, then empty the queue."""
        for job_id, transfer_spec in self._unsubmitted_store_jobs:
            success = self.worker.transfer_async(job_id, transfer_spec)
            assert success
        self._unsubmitted_store_jobs.clear()

    def _register_handlers(
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ):
        """Register with the worker every handler the spec provides."""
        for src_cls, dst_cls, handler in self.spec.get_handlers(
            kv_caches, attn_backends
        ):
            self.worker.register_handler(src_cls, dst_cls, handler)

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register handlers for per-layer KV caches.

        Args:
            kv_caches: maps layer name to that layer's KV cache tensor.
        """
        layer_names = list(kv_caches.keys())
        layers = get_layers_from_vllm_config(
            self.spec.vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
            layer_names,
        )
        attn_backends = {
            layer_name: layers[layer_name].get_attn_backend()
            for layer_name in layer_names
        }
        self._register_handlers(kv_caches, attn_backends)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        """Register handlers for a single KV cache tensor shared by all layers.

        The tensor and backend are registered under the name "ALL_LAYERS".
        """
        cross_layer_name = "ALL_LAYERS"
        kv_caches = {cross_layer_name: kv_cache}
        attn_backends = {cross_layer_name: attn_backend}
        self._register_handlers(kv_caches, attn_backends)

    def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
        """Submit deferred stores, then wait on the stores of flushed requests."""
        self._submit_unsubmitted_store_jobs()

        for req_id in kv_connector_metadata.reqs_to_flush or ():
            job_ids = self._store_jobs.get(req_id)
            if job_ids:
                self.worker.wait(job_ids)

    def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
        """Submit deferred stores, then start one load job per request to load."""
        self._submit_unsubmitted_store_jobs()

        for req_id, transfer_spec in metadata.reqs_to_load.items():
            job_id = self._generate_job_id()
            self._jobs[job_id] = (req_id, False)
            # at most one load job per request at any time
            assert req_id not in self._load_job
            self._load_job[req_id] = job_id
            success = self.worker.transfer_async(job_id, transfer_spec)
            assert success

    def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
        """Create (but do not yet submit) one store job per request to store."""
        for req_id, transfer_spec in metadata.reqs_to_store.items():
            job_id = self._generate_job_id()
            self._jobs[job_id] = (req_id, True)
            self._store_jobs[req_id].add(job_id)
            # NOTE(orozery): defer the store to the beginning of the next engine step,
            # so that offloading starts AFTER transfers related to token sampling,
            # thereby avoiding delays to token generation due to offloading.
            self._unsubmitted_store_jobs.append((job_id, transfer_spec))

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens.

        Args:
            finished_req_ids: ids of requests that finished generating tokens.

        Returns:
            ids of requests that have finished asynchronous transfer
            tuple of (sending/saving ids, recving/loading ids).
        """
        finished_sending: set[str] = set()
        finished_recving: set[str] = set()
        for transfer_result in self.worker.get_finished():
            # we currently do not support job failures
            job_id = transfer_result.job_id
            assert transfer_result.success
            req_id, store = self._jobs.pop(job_id)
            if (
                transfer_result.transfer_time
                and transfer_result.transfer_size is not None
                and transfer_result.transfer_type is not None
            ):
                self.kv_connector_stats.record_transfer(
                    num_bytes=transfer_result.transfer_size,
                    time=transfer_result.transfer_time,
                    transfer_type=transfer_result.transfer_type,
                )
            if store:
                req_jobs = self._store_jobs[req_id]
                req_jobs.remove(job_id)
                if req_jobs:
                    # the request still has store jobs in flight
                    continue

                if req_id in self._finished_reqs_waiting_for_store:
                    self._finished_reqs_waiting_for_store.remove(req_id)
                    finished_sending.add(req_id)
                    del self._store_jobs[req_id]
            else:
                req_job = self._load_job[req_id]
                assert job_id == req_job
                del self._load_job[req_id]
                finished_recving.add(req_id)

        for req_id in finished_req_ids:
            pending_req_jobs = self._store_jobs.get(req_id)
            if pending_req_jobs:
                # report the request only once its store jobs complete
                self._finished_reqs_waiting_for_store.add(req_id)
            elif pending_req_jobs is not None:
                # an empty set: all store jobs of the request already completed
                finished_sending.add(req_id)
                del self._store_jobs[req_id]

        return finished_sending, finished_recving

    def get_kv_connector_stats(self) -> KVConnectorStats | None:
        """
        Get the KV transfer stats for the connector.

        Returns:
            The stats collected since the previous call, or None if nothing
            was recorded. A fresh stats container replaces the returned one.
        """
        if self.kv_connector_stats.is_empty():
            return None
        # Clear stats for next iteration
        kv_connector_stats = self.kv_connector_stats
        self.kv_connector_stats = OffloadingConnectorStats()
        return kv_connector_stats
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
View file @
525f2eeb
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment