Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c16ba61
Unverified
Commit
4c16ba61
authored
Jan 11, 2026
by
Or Ozeri
Committed by
GitHub
Jan 11, 2026
Browse files
[KVConnector] OffloadingConnector: Fix bug in handling of preemptions (#29870)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
bde57ab2
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
255 additions
and
64 deletions
+255
-64
tests/v1/kv_connector/unit/test_offloading_connector.py
tests/v1/kv_connector/unit/test_offloading_connector.py
+175
-64
tests/v1/kv_offload/test_worker.py
tests/v1/kv_offload/test_worker.py
+12
-0
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+10
-0
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
...buted/kv_transfer/kv_connector/v1/offloading_connector.py
+24
-0
vllm/v1/kv_offload/worker/cpu_gpu.py
vllm/v1/kv_offload/worker/cpu_gpu.py
+10
-0
vllm/v1/kv_offload/worker/worker.py
vllm/v1/kv_offload/worker/worker.py
+19
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+5
-0
No files found.
tests/v1/kv_connector/unit/test_offloading_connector.py
View file @
4c16ba61
...
@@ -64,8 +64,11 @@ class MockLoadStoreSpec(LoadStoreSpec):
...
@@ -64,8 +64,11 @@ class MockLoadStoreSpec(LoadStoreSpec):
class
MockOffloadingHandler
(
OffloadingHandler
):
class
MockOffloadingHandler
(
OffloadingHandler
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
transfer_specs
:
dict
[
int
,
TransferSpec
]
=
{}
self
.
completed_transfers
:
list
[
TransferResult
]
=
[]
self
.
completed_transfers
:
list
[
TransferResult
]
=
[]
self
.
completed_specs
:
list
[
TransferSpec
]
=
[]
self
.
waiting_jobs
:
set
[
int
]
=
set
()
self
.
completed_jobs
:
list
[
int
]
=
[]
self
.
flushed_jobs
:
set
[
int
]
=
set
()
def
get_finished
(
self
)
->
list
[
TransferResult
]:
def
get_finished
(
self
)
->
list
[
TransferResult
]:
finished
=
self
.
completed_transfers
finished
=
self
.
completed_transfers
...
@@ -73,10 +76,21 @@ class MockOffloadingHandler(OffloadingHandler):
...
@@ -73,10 +76,21 @@ class MockOffloadingHandler(OffloadingHandler):
return
finished
return
finished
def
transfer_async
(
self
,
job_id
:
int
,
spec
:
TransferSpec
)
->
bool
:
def
transfer_async
(
self
,
job_id
:
int
,
spec
:
TransferSpec
)
->
bool
:
self
.
completed_specs
.
append
(
spec
)
self
.
transfer_specs
[
job_id
]
=
spec
self
.
completed_transfers
.
append
((
job_id
,
True
)
)
self
.
waiting_jobs
.
add
(
job_id
)
return
True
return
True
def
complete_jobs
(
self
,
job_ids
:
set
[
int
])
->
None
:
for
job_id
in
job_ids
:
if
job_id
in
self
.
waiting_jobs
:
self
.
waiting_jobs
.
remove
(
job_id
)
self
.
completed_jobs
.
append
(
job_id
)
self
.
completed_transfers
.
append
((
job_id
,
True
))
def
wait
(
self
,
job_ids
:
set
[
int
])
->
None
:
self
.
flushed_jobs
|=
job_ids
self
.
complete_jobs
(
job_ids
)
class
MockOffloadingSpec
(
OffloadingSpec
):
class
MockOffloadingSpec
(
OffloadingSpec
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
...
@@ -98,9 +112,22 @@ class MockOffloadingSpec(OffloadingSpec):
...
@@ -98,9 +112,22 @@ class MockOffloadingSpec(OffloadingSpec):
yield
GPULoadStoreSpec
,
MockLoadStoreSpec
,
self
.
handler
yield
GPULoadStoreSpec
,
MockLoadStoreSpec
,
self
.
handler
yield
MockLoadStoreSpec
,
GPULoadStoreSpec
,
self
.
handler
yield
MockLoadStoreSpec
,
GPULoadStoreSpec
,
self
.
handler
def
complete_transfers
(
self
):
self
.
handler
.
complete_jobs
(
self
.
handler
.
waiting_jobs
.
copy
())
def
get_completed_transfers
(
self
)
->
list
[
TransferSpec
]:
def
get_completed_transfers
(
self
)
->
list
[
TransferSpec
]:
specs
=
self
.
handler
.
completed_specs
specs
=
[
self
.
handler
.
completed_specs
=
[]
self
.
handler
.
transfer_specs
[
job_id
]
for
job_id
in
self
.
handler
.
completed_jobs
]
self
.
handler
.
completed_jobs
.
clear
()
return
specs
def
get_flushed_transfers
(
self
):
specs
=
[
self
.
handler
.
transfer_specs
[
job_id
]
for
job_id
in
self
.
handler
.
flushed_jobs
]
self
.
handler
.
flushed_jobs
.
clear
()
return
specs
return
specs
...
@@ -170,12 +197,9 @@ class RequestRunner:
...
@@ -170,12 +197,9 @@ class RequestRunner:
# mapping (offloading address) -> gpu_block_index
# mapping (offloading address) -> gpu_block_index
self
.
offloaded
:
dict
[
Any
,
int
]
=
{}
self
.
offloaded
:
dict
[
Any
,
int
]
=
{}
self
.
pending_loads_count
:
int
=
0
self
.
pending_stores_count
:
int
=
0
self
.
unsubmitted_stores_count
=
0
self
.
completed_loads
:
list
[
TransferSummary
]
=
[]
self
.
completed_loads
:
list
[
TransferSummary
]
=
[]
self
.
completed_stores
:
list
[
TransferSummary
]
=
[]
self
.
completed_stores
:
list
[
TransferSummary
]
=
[]
self
.
flushed_gpu_block_indexes
:
set
[
int
]
=
set
()
# maps {block_id: block_offset}
# maps {block_id: block_offset}
self
.
gpu_block_index
:
dict
[
int
,
int
]
=
{}
self
.
gpu_block_index
:
dict
[
int
,
int
]
=
{}
...
@@ -202,10 +226,18 @@ class RequestRunner:
...
@@ -202,10 +226,18 @@ class RequestRunner:
self
.
scheduler
.
add_request
(
req
)
self
.
scheduler
.
add_request
(
req
)
def
_wait_for_transfers
(
self
):
def
_parse_transfers
(
self
):
for
transfer_spec
in
self
.
offloading_spec
.
get_flushed_transfers
():
src_spec
,
dst_spec
=
transfer_spec
assert
isinstance
(
src_spec
,
GPULoadStoreSpec
)
for
block_id
in
src_spec
.
block_ids
:
self
.
flushed_gpu_block_indexes
.
add
(
self
.
gpu_block_index
[
block_id
.
item
()]
)
block_size_factor
=
self
.
offloaded_block_size
//
self
.
gpu_block_size
block_size_factor
=
self
.
offloaded_block_size
//
self
.
gpu_block_size
while
self
.
pending_loads_count
or
self
.
pending_stores_count
:
for
transfer_spec
in
self
.
offloading_spec
.
get_completed_transfers
():
for
transfer_spec
in
self
.
offloading_spec
.
get_completed_transfers
():
src_spec
,
dst_spec
=
transfer_spec
src_spec
,
dst_spec
=
transfer_spec
...
@@ -237,7 +269,6 @@ class RequestRunner:
...
@@ -237,7 +269,6 @@ class RequestRunner:
self
.
completed_stores
.
append
(
self
.
completed_stores
.
append
(
TransferSummary
(
gpu_block_indices
,
offload_addresses
)
TransferSummary
(
gpu_block_indices
,
offload_addresses
)
)
)
self
.
pending_stores_count
-=
1
else
:
else
:
remainder_sub_block_count
=
len
(
offload_addresses
)
-
len
(
remainder_sub_block_count
=
len
(
offload_addresses
)
-
len
(
gpu_block_indices
gpu_block_indices
...
@@ -249,7 +280,6 @@ class RequestRunner:
...
@@ -249,7 +280,6 @@ class RequestRunner:
self
.
completed_loads
.
append
(
self
.
completed_loads
.
append
(
TransferSummary
(
gpu_block_indices
,
offload_addresses
)
TransferSummary
(
gpu_block_indices
,
offload_addresses
)
)
)
self
.
pending_loads_count
-=
1
def
_update_gpu_block_idx
(
self
):
def
_update_gpu_block_idx
(
self
):
for
blocks
in
self
.
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
for
blocks
in
self
.
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
...
@@ -258,18 +288,19 @@ class RequestRunner:
...
@@ -258,18 +288,19 @@ class RequestRunner:
for
block_idx
,
block
in
enumerate
(
blocks
):
for
block_idx
,
block
in
enumerate
(
blocks
):
self
.
gpu_block_index
[
block
.
block_id
]
=
block_idx
self
.
gpu_block_index
[
block
.
block_id
]
=
block_idx
def
_run
(
self
,
decoded_tokens
:
list
[
int
]):
def
_run
(
self
,
decoded_tokens
:
list
[
int
]
,
complete_transfers
:
bool
):
"""
"""
Runs multiple engine (scheduler + worker) steps.
Runs multiple engine (scheduler + worker) steps.
Assumes a single request is running.
Assumes a single request is running.
Args:
Args:
decoded_tokens: the tokens to yield at each step.
decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
"""
"""
tokens_iter
=
iter
(
decoded_tokens
)
tokens_iter
=
iter
(
decoded_tokens
)
token_id
=
next
(
tokens_iter
,
None
)
token_id
=
next
(
tokens_iter
,
None
)
while
token_id
is
not
Non
e
:
while
Tru
e
:
assert
self
.
scheduler
.
requests
assert
self
.
scheduler
.
requests
scheduler_output
=
self
.
scheduler
.
schedule
()
scheduler_output
=
self
.
scheduler
.
schedule
()
...
@@ -279,10 +310,10 @@ class RequestRunner:
...
@@ -279,10 +310,10 @@ class RequestRunner:
assert
kv_connector_metadata
is
not
None
assert
kv_connector_metadata
is
not
None
assert
isinstance
(
kv_connector_metadata
,
OffloadingConnectorMetadata
)
assert
isinstance
(
kv_connector_metadata
,
OffloadingConnectorMetadata
)
self
.
pending_loads_count
+=
len
(
kv_connector_metadata
.
reqs_to_load
)
if
scheduler_output
.
preempted_req_ids
:
self
.
worker_connector
.
handle_preemptions
(
self
.
pending_stores_count
+=
self
.
unsubmitted_stores_count
scheduler_output
.
preempted_req_ids
self
.
unsubmitted_stores_count
=
len
(
kv_connector_metadata
.
reqs_to_store
)
)
self
.
worker_connector
.
bind_connector_metadata
(
kv_connector_metadata
)
self
.
worker_connector
.
bind_connector_metadata
(
kv_connector_metadata
)
self
.
worker_connector
.
start_load_kv
(
self
.
_dummy_ctx
)
self
.
worker_connector
.
start_load_kv
(
self
.
_dummy_ctx
)
...
@@ -290,6 +321,9 @@ class RequestRunner:
...
@@ -290,6 +321,9 @@ class RequestRunner:
if
scheduler_output
.
total_num_scheduled_tokens
>
0
:
if
scheduler_output
.
total_num_scheduled_tokens
>
0
:
self
.
worker_connector
.
wait_for_save
()
self
.
worker_connector
.
wait_for_save
()
if
complete_transfers
:
self
.
offloading_spec
.
complete_transfers
()
finished_sending
,
finished_recving
=
self
.
worker_connector
.
get_finished
(
finished_sending
,
finished_recving
=
self
.
worker_connector
.
get_finished
(
scheduler_output
.
finished_req_ids
scheduler_output
.
finished_req_ids
)
)
...
@@ -300,7 +334,7 @@ class RequestRunner:
...
@@ -300,7 +334,7 @@ class RequestRunner:
reqs
=
self
.
scheduler
.
running
,
reqs
=
self
.
scheduler
.
running
,
finished_sending
=
finished_sending
,
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
finished_recving
=
finished_recving
,
token_id
=
token_id
,
token_id
=
token_id
or
0
,
)
)
if
self
.
scheduler
.
running
:
if
self
.
scheduler
.
running
:
...
@@ -308,7 +342,10 @@ class RequestRunner:
...
@@ -308,7 +342,10 @@ class RequestRunner:
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
self
.
_wait_for_transfers
()
if
token_id
is
None
:
break
self
.
_parse_transfers
()
# run one more step to update finished stored
# run one more step to update finished stored
if
EOS_TOKEN_ID
in
decoded_tokens
:
if
EOS_TOKEN_ID
in
decoded_tokens
:
...
@@ -333,8 +370,10 @@ class RequestRunner:
...
@@ -333,8 +370,10 @@ class RequestRunner:
def
run
(
def
run
(
self
,
self
,
decoded_tokens
:
list
[
int
],
decoded_tokens
:
list
[
int
],
complete_transfers
:
bool
=
True
,
expected_stored_gpu_block_indexes
:
tuple
[
int
,
...]
=
(),
expected_stored_gpu_block_indexes
:
tuple
[
int
,
...]
=
(),
expected_loaded_gpu_block_indexes
:
tuple
[
int
,
...]
=
(),
expected_loaded_gpu_block_indexes
:
tuple
[
int
,
...]
=
(),
expected_flushed_gpu_block_indexes
:
tuple
[
int
,
...]
=
(),
):
):
"""
"""
Runs multiple engine (scheduler + worker) steps.
Runs multiple engine (scheduler + worker) steps.
...
@@ -342,14 +381,17 @@ class RequestRunner:
...
@@ -342,14 +381,17 @@ class RequestRunner:
Args:
Args:
decoded_tokens: the tokens to yield at each step.
decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
expected_stored_gpu_block_indexes: GPU block indexes
expected_stored_gpu_block_indexes: GPU block indexes
that are expected to be written during the run.
that are expected to be written during the run.
expected_loaded_gpu_block_indexes: GPU block indexes
expected_loaded_gpu_block_indexes: GPU block indexes
that are expected to be loaded during the run.
that are expected to be loaded during the run.
expected_flushed_gpu_block_indexes: GPU block indexes
that are expected to be flushed during the run.
"""
"""
self
.
manager
.
reset_mock
()
self
.
manager
.
reset_mock
()
self
.
_run
(
decoded_tokens
)
self
.
_run
(
decoded_tokens
,
complete_transfers
)
loaded_gpu_block_indexes
:
set
[
int
]
=
set
()
loaded_gpu_block_indexes
:
set
[
int
]
=
set
()
for
transfer
in
self
.
completed_loads
:
for
transfer
in
self
.
completed_loads
:
...
@@ -373,6 +415,9 @@ class RequestRunner:
...
@@ -373,6 +415,9 @@ class RequestRunner:
assert
set
(
expected_stored_gpu_block_indexes
)
==
stored_gpu_block_indexes
assert
set
(
expected_stored_gpu_block_indexes
)
==
stored_gpu_block_indexes
self
.
completed_stores
.
clear
()
self
.
completed_stores
.
clear
()
assert
set
(
expected_flushed_gpu_block_indexes
)
==
self
.
flushed_gpu_block_indexes
self
.
flushed_gpu_block_indexes
.
clear
()
@
pytest
.
fixture
@
pytest
.
fixture
def
request_runner
():
def
request_runner
():
...
@@ -539,3 +584,69 @@ def test_offloading_connector(request_runner):
...
@@ -539,3 +584,69 @@ def test_offloading_connector(request_runner):
assert
isinstance
(
event
,
BlockRemoved
)
assert
isinstance
(
event
,
BlockRemoved
)
assert
event
.
block_hashes
==
to_hashes
([
4
,
5
,
6
])
assert
event
.
block_hashes
==
to_hashes
([
4
,
5
,
6
])
assert
event
.
medium
==
"B"
assert
event
.
medium
==
"B"
def
test_request_preemption
(
request_runner
):
offloaded_block_size
=
12
gpu_block_size
=
4
num_gpu_blocks
=
100
runner
=
request_runner
(
offloaded_block_size
=
offloaded_block_size
,
gpu_block_size
=
gpu_block_size
,
num_gpu_blocks
=
num_gpu_blocks
,
)
free_block_queue
=
runner
.
scheduler
.
kv_cache_manager
.
block_pool
.
free_block_queue
num_free_blocks_empty
=
free_block_queue
.
num_free_blocks
# 2 blocks, store all, without flushing
# blocks = [0, 1, 2], [3, 4, 5]
runner
.
new_request
(
token_ids
=
[
0
]
*
offloaded_block_size
*
2
)
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
block_hashes
:
generate_store_output
(
block_hashes
)
)
runner
.
run
(
decoded_tokens
=
[
0
],
complete_transfers
=
False
,
)
# decode 2 more blocks - 1 gpu block, storing [6, 7, 8] (no flush)
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
block_hashes
:
generate_store_output
(
block_hashes
)
)
runner
.
run
(
decoded_tokens
=
[
0
]
*
(
2
*
offloaded_block_size
-
gpu_block_size
),
complete_transfers
=
False
,
)
# simulate KV cache running out of space
free_block_queue
.
num_free_blocks
=
0
# request should be preempted now
runner
.
run
(
decoded_tokens
=
[],
complete_transfers
=
False
,
expected_flushed_gpu_block_indexes
=
(
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
),
expected_stored_gpu_block_indexes
=
(
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
),
)
# restore KV cache space and reset GPU prefix cache
free_block_queue
.
num_free_blocks
=
num_free_blocks_empty
runner
.
scheduler
.
reset_prefix_cache
()
# request should now return from preemption
# re-load [0, ..., 8] from the CPU and store [9, 10, 11]
runner
.
manager
.
lookup
.
return_value
=
3
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
block_hashes
:
generate_store_output
(
block_hashes
)
)
runner
.
run
(
decoded_tokens
=
[
0
]
*
gpu_block_size
,
expected_loaded_gpu_block_indexes
=
(
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
),
)
runner
.
run
(
decoded_tokens
=
[
EOS_TOKEN_ID
],
expected_stored_gpu_block_indexes
=
(
9
,
10
,
11
),
)
tests/v1/kv_offload/test_worker.py
View file @
4c16ba61
...
@@ -63,6 +63,12 @@ class OffloadingHandler1To2(OffloadingHandler):
...
@@ -63,6 +63,12 @@ class OffloadingHandler1To2(OffloadingHandler):
del
self
.
transfers
[
job_id
]
del
self
.
transfers
[
job_id
]
return
finished
return
finished
def
wait
(
self
,
job_ids
:
set
[
int
])
->
None
:
for
job_id
in
job_ids
:
spec
=
self
.
transfers
.
get
(
job_id
)
if
spec
:
assert
spec
.
finished
class
OffloadingHandler2To1
(
OffloadingHandler
):
class
OffloadingHandler2To1
(
OffloadingHandler
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -84,6 +90,12 @@ class OffloadingHandler2To1(OffloadingHandler):
...
@@ -84,6 +90,12 @@ class OffloadingHandler2To1(OffloadingHandler):
del
self
.
transfers
[
job_id
]
del
self
.
transfers
[
job_id
]
return
finished
return
finished
def
wait
(
self
,
job_ids
:
set
[
int
])
->
None
:
for
job_id
in
job_ids
:
spec
=
self
.
transfers
.
get
(
job_id
)
if
spec
:
assert
spec
.
finished
def
test_offloading_worker
():
def
test_offloading_worker
():
"""
"""
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
4c16ba61
...
@@ -25,6 +25,9 @@ The class provides the following primitives:
...
@@ -25,6 +25,9 @@ The class provides the following primitives:
Worker-side: runs in each worker, loads/saves KV cache to/from
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
the Connector based on the metadata.
handle_preemptions() - called if there are preempted requests,
before their blocks are overwritten
start_load_kv() - starts loading all KVs (maybe async)
start_load_kv() - starts loading all KVs (maybe async)
wait_for_layer_load() - blocks until layer i load is done
wait_for_layer_load() - blocks until layer i load is done
...
@@ -262,6 +265,13 @@ class KVConnectorBase_V1(ABC):
...
@@ -262,6 +265,13 @@ class KVConnectorBase_V1(ABC):
"""
"""
return
return
def
handle_preemptions
(
self
,
preempted_req_ids
:
set
[
str
]):
"""
Handle preempted requests BEFORE their blocks are overwritten.
Needed for connectors which use async saves (e.g., OffloadingConnector)
"""
return
@
abstractmethod
@
abstractmethod
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
**
kwargs
:
Any
)
->
None
:
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
**
kwargs
:
Any
)
->
None
:
"""
"""
...
...
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
View file @
4c16ba61
...
@@ -75,6 +75,10 @@ class OffloadingConnector(KVConnectorBase_V1):
...
@@ -75,6 +75,10 @@ class OffloadingConnector(KVConnectorBase_V1):
assert
self
.
connector_worker
is
not
None
assert
self
.
connector_worker
is
not
None
self
.
connector_worker
.
register_cross_layers_kv_cache
(
kv_cache
,
attn_backend
)
self
.
connector_worker
.
register_cross_layers_kv_cache
(
kv_cache
,
attn_backend
)
def
handle_preemptions
(
self
,
preempted_req_ids
:
set
[
str
]):
assert
self
.
connector_worker
is
not
None
self
.
connector_worker
.
handle_preemptions
(
preempted_req_ids
)
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
**
kwargs
)
->
None
:
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
**
kwargs
)
->
None
:
assert
self
.
connector_worker
is
not
None
assert
self
.
connector_worker
is
not
None
assert
isinstance
(
self
.
_connector_metadata
,
OffloadingConnectorMetadata
)
assert
isinstance
(
self
.
_connector_metadata
,
OffloadingConnectorMetadata
)
...
@@ -348,6 +352,15 @@ class OffloadingConnectorScheduler:
...
@@ -348,6 +352,15 @@ class OffloadingConnectorScheduler:
reqs_to_store
=
self
.
_get_reqs_to_store
(
scheduler_output
),
reqs_to_store
=
self
.
_get_reqs_to_store
(
scheduler_output
),
)
)
self
.
_reqs_to_load
=
{}
self
.
_reqs_to_load
=
{}
# NOTE (orozery): we should move this logic to update_connector_output
# once KVConnectorOutput allows us to report completed transfers
for
req_id
in
scheduler_output
.
preempted_req_ids
or
():
block_hashes
=
self
.
_reqs_being_stored
.
get
(
req_id
)
if
block_hashes
:
self
.
manager
.
complete_store
(
block_hashes
)
block_hashes
.
clear
()
return
meta
return
meta
def
update_connector_output
(
self
,
connector_output
:
KVConnectorOutput
):
def
update_connector_output
(
self
,
connector_output
:
KVConnectorOutput
):
...
@@ -466,6 +479,17 @@ class OffloadingConnectorWorker:
...
@@ -466,6 +479,17 @@ class OffloadingConnectorWorker:
attn_backends
=
{
cross_layer_name
:
attn_backend
}
attn_backends
=
{
cross_layer_name
:
attn_backend
}
self
.
_register_handlers
(
kv_caches
,
attn_backends
)
self
.
_register_handlers
(
kv_caches
,
attn_backends
)
def
handle_preemptions
(
self
,
preempted_req_ids
:
set
[
str
]):
for
job_id
,
transfer_spec
in
self
.
_unsubmitted_store_jobs
:
success
=
self
.
worker
.
transfer_async
(
job_id
,
transfer_spec
)
assert
success
self
.
_unsubmitted_store_jobs
.
clear
()
for
req_id
in
preempted_req_ids
:
job_ids
=
self
.
_store_jobs
.
get
(
req_id
)
if
job_ids
:
self
.
worker
.
wait
(
job_ids
)
def
start_kv_transfers
(
self
,
metadata
:
OffloadingConnectorMetadata
):
def
start_kv_transfers
(
self
,
metadata
:
OffloadingConnectorMetadata
):
for
job_id
,
transfer_spec
in
self
.
_unsubmitted_store_jobs
:
for
job_id
,
transfer_spec
in
self
.
_unsubmitted_store_jobs
:
success
=
self
.
worker
.
transfer_async
(
job_id
,
transfer_spec
)
success
=
self
.
worker
.
transfer_async
(
job_id
,
transfer_spec
)
...
...
vllm/v1/kv_offload/worker/cpu_gpu.py
View file @
4c16ba61
...
@@ -96,6 +96,8 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
...
@@ -96,6 +96,8 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
assert
len
(
src_tensors
)
>
0
assert
len
(
src_tensors
)
>
0
self
.
gpu_to_cpu
:
bool
=
self
.
src_tensors
[
0
].
is_cuda
self
.
gpu_to_cpu
:
bool
=
self
.
src_tensors
[
0
].
is_cuda
# job_id -> event
self
.
_transfer_events
:
dict
[
int
,
torch
.
Event
]
=
{}
# queue of transfers (job_id, stream, event)
# queue of transfers (job_id, stream, event)
self
.
_transfers
:
deque
[
tuple
[
int
,
torch
.
cuda
.
Stream
,
torch
.
Event
]]
=
deque
()
self
.
_transfers
:
deque
[
tuple
[
int
,
torch
.
cuda
.
Stream
,
torch
.
Event
]]
=
deque
()
# list of CUDA streams available for re-use
# list of CUDA streams available for re-use
...
@@ -152,6 +154,7 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
...
@@ -152,6 +154,7 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
ops
.
swap_blocks
(
src_tensor
,
dst_tensor
,
src_to_dst_tensor
)
ops
.
swap_blocks
(
src_tensor
,
dst_tensor
,
src_to_dst_tensor
)
event
.
record
(
stream
)
event
.
record
(
stream
)
self
.
_transfer_events
[
job_id
]
=
event
self
.
_transfers
.
append
((
job_id
,
stream
,
event
))
self
.
_transfers
.
append
((
job_id
,
stream
,
event
))
# success
# success
...
@@ -164,8 +167,15 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
...
@@ -164,8 +167,15 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
results
.
append
((
job_id
,
True
))
results
.
append
((
job_id
,
True
))
self
.
_stream_pool
.
append
(
stream
)
self
.
_stream_pool
.
append
(
stream
)
self
.
_event_pool
.
append
(
event
)
self
.
_event_pool
.
append
(
event
)
del
self
.
_transfer_events
[
job_id
]
return
results
return
results
def
wait
(
self
,
job_ids
:
set
[
int
]):
for
job_id
in
job_ids
:
event
=
self
.
_transfer_events
.
get
(
job_id
)
if
event
is
not
None
:
event
.
synchronize
()
class
CpuGpuOffloadingHandlers
:
class
CpuGpuOffloadingHandlers
:
def
__init__
(
def
__init__
(
...
...
vllm/v1/kv_offload/worker/worker.py
View file @
4c16ba61
...
@@ -53,6 +53,15 @@ class OffloadingHandler(ABC):
...
@@ -53,6 +53,15 @@ class OffloadingHandler(ABC):
"""
"""
pass
pass
@
abstractmethod
def
wait
(
self
,
job_ids
:
set
[
int
])
->
None
:
"""
Wait for jobs to finish (blocking).
Args:
job_ids: The set of job IDs to wait for.
"""
class
OffloadingWorker
:
class
OffloadingWorker
:
"""
"""
...
@@ -142,3 +151,13 @@ class OffloadingWorker:
...
@@ -142,3 +151,13 @@ class OffloadingWorker:
for
handler
in
self
.
handlers
:
for
handler
in
self
.
handlers
:
finished
.
extend
(
handler
.
get_finished
())
finished
.
extend
(
handler
.
get_finished
())
return
finished
return
finished
def
wait
(
self
,
job_ids
:
set
[
int
])
->
None
:
"""
Wait for jobs to finish (blocking).
Args:
job_ids: The set of job IDs to wait for.
"""
for
handler
in
self
.
handlers
:
handler
.
wait
(
job_ids
)
vllm/v1/worker/gpu_model_runner.py
View file @
4c16ba61
...
@@ -3112,6 +3112,11 @@ class GPUModelRunner(
...
@@ -3112,6 +3112,11 @@ class GPUModelRunner(
"after execute_model() returns None."
"after execute_model() returns None."
)
)
if
scheduler_output
.
preempted_req_ids
and
has_kv_transfer_group
():
get_kv_transfer_group
().
handle_preemptions
(
scheduler_output
.
preempted_req_ids
)
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
with
(
with
(
record_function_or_nullcontext
(
"gpu_model_runner: preprocess"
),
record_function_or_nullcontext
(
"gpu_model_runner: preprocess"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment