Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b18b417f
Unverified
Commit
b18b417f
authored
Jul 28, 2025
by
Kuntai Du
Committed by
GitHub
Jul 28, 2025
Browse files
Revert "[V1] Exception Handling when Loading KV Cache from Remote Store" (#21778)
Signed-off-by:
KuntaiDu
<
kuntai@uchicago.edu
>
parent
9ba1c88a
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
5 additions
and
229 deletions
+5
-229
tests/v1/kv_connector/kv_load_exception_handling/random_drop_connector.py
...ector/kv_load_exception_handling/random_drop_connector.py
+0
-120
tests/v1/kv_connector/kv_load_exception_handling/test.sh
tests/v1/kv_connector/kv_load_exception_handling/test.sh
+0
-16
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+1
-15
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+0
-20
vllm/sequence.py
vllm/sequence.py
+0
-2
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+0
-30
vllm/v1/outputs.py
vllm/v1/outputs.py
+0
-3
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-8
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+1
-3
vllm/v1/worker/kv_connector_model_runner_mixin.py
vllm/v1/worker/kv_connector_model_runner_mixin.py
+1
-12
No files found.
tests/v1/kv_connector/kv_load_exception_handling/random_drop_connector.py
deleted
100644 → 0
View file @
9ba1c88a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
logging
import
random
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
import
torch
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.request
import
Request
# Configure the root logger to emit INFO-level records so the connector's
# per-request log lines are visible while the test server runs.
logging.basicConfig(level=logging.INFO)

# The connector logs through the root logger (no logger name is passed).
logger = logging.getLogger()
@dataclass
class RandomDropConnectorMetadata(KVConnectorMetadata):
    """Metadata built by ``RandomDropConnector.build_connector_meta``.

    It is read back in ``RandomDropConnector.start_load_kv``, which iterates
    over ``req_meta`` to simulate a load for every listed request.
    """

    # request ID -> leading slice of the request's prompt token IDs that the
    # connector claimed as an external hit (see update_state_after_alloc).
    req_meta: dict[str, list[int]]
class RandomDropConnector(KVConnectorBase_V1):
    """
    A connector designed for fault tolerance testing by randomly dropping
    kv data during the process of loading or receiving KV cache.

    This class simulates real-world scenarios where requests or data
    might be lost or timeout, allowing developers to test and validate the
    system's ability to handle such failures.

    No KV data is actually transferred: ``load_kv`` only picks a random
    number of "loaded" tokens for each request.

    Attributes:
        failure_request (list[str]): IDs of requests registered through
            ``add_failure_request``. The next
            ``get_num_new_matched_tokens`` call for such a request reports
            no external hit and removes the ID again.
        _reqs_need_recv (dict[str, list[int]]): Request ID -> prompt token
            IDs claimed as an external hit. Filled by
            ``update_state_after_alloc`` and drained by
            ``build_connector_meta``.
        _finish_load (dict[str, int]): Request ID -> number of tokens the
            simulated load reported. Filled by ``start_load_kv`` and
            drained by ``get_finished_loading``.
        chunk_size (int): Set to 256; not read anywhere in this class.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self.failure_request: list[str] = []
        self._reqs_need_recv: dict[str, list[int]] = {}
        self._finish_load: dict[str, int] = {}
        self.chunk_size = 256

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """Report how many prompt tokens are claimed as an external hit.

        Args:
            request: The request being scheduled.
            num_computed_tokens: Unused by this test connector.

        Returns:
            ``(0, False)`` once for a request previously registered via
            ``add_failure_request`` (the registration is consumed), and
            ``(num_prompt_tokens - 1, True)`` otherwise. The meaning of
            the boolean is defined by the base-class contract
            (NOTE(review): presumably "load asynchronously" -- confirm
            against KVConnectorBase_V1).
        """
        if request.request_id in self.failure_request:
            # One-shot: the request is only skipped on its first retry.
            self.failure_request.remove(request.request_id)
            return 0, False
        num_external_hit_tokens = request.num_prompt_tokens - 1
        logger.info(
            "request %s num_prompt_tokens %d num_external_hit_tokens %d",
            request.request_id, request.num_prompt_tokens,
            num_external_hit_tokens)
        return num_external_hit_tokens, True

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Remember which prompt tokens must be "received" for a request.

        Nothing is recorded when ``num_external_tokens`` is not positive.
        ``blocks`` is unused.
        """
        if num_external_tokens > 0:
            self._reqs_need_recv[request.request_id] = (
                request.prompt_token_ids[:num_external_tokens])

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Hand the pending receive requests over as connector metadata.

        The pending map is snapshotted and then cleared, so every request
        is reported exactly once. ``scheduler_output`` is unused.
        """
        req_meta = self._reqs_need_recv.copy()
        self._reqs_need_recv.clear()
        return RandomDropConnectorMetadata(req_meta)

    def add_failure_request(self, request: "Request"):
        """Register a request so its next external-hit lookup returns 0."""
        self.failure_request.append(request.request_id)

    def start_load_kv(self, forward_context, **kwargs) -> None:
        """Simulate loading for every request in the connector metadata.

        For each request the randomly chosen number of loaded tokens is
        logged and stored in ``_finish_load`` for later retrieval through
        ``get_finished_loading``. ``forward_context`` and ``kwargs`` are
        unused.
        """
        for request_id, hit_tokens in self._get_connector_metadata(
        ).req_meta.items():
            num_actual_load_tokens = self.load_kv(request_id, hit_tokens)
            logger.info("request %s hit_tokens %d num_actual_load_tokens %d",
                        request_id, len(hit_tokens), num_actual_load_tokens)
            self._finish_load[request_id] = num_actual_load_tokens

    def wait_for_layer_load(self, layer_name: str) -> None:
        """No-op: this connector never loads real layer data."""
        pass

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """No-op: this connector never saves KV data."""
        pass

    def wait_for_save(self):
        """No-op: there is no pending save to wait for."""
        pass

    def load_kv(self, request_id, hit_tokens):
        """Return a random token count in ``[0, len(hit_tokens)]``.

        The value stands in for the number of tokens actually loaded;
        ``request_id`` is unused.
        """
        num_actual_load_tokens = random.randint(0, len(hit_tokens))
        return num_actual_load_tokens

    def get_finished_loading(
            self,
            scheduler_output: "SchedulerOutput | None" = None
    ) -> dict[str, int]:
        """Return and reset the per-request loaded-token counts.

        Args:
            scheduler_output: Ignored. It is accepted because the base
                class declares this parameter and the model-runner mixin
                passes it positionally; without it that call raises
                ``TypeError``. It defaults to ``None`` so existing
                zero-argument callers keep working.

        Returns:
            Request ID -> number of tokens reported by ``start_load_kv``
            since the previous call; empty when nothing is pending.
        """
        if not self._finish_load:
            return {}
        finished_loading = self._finish_load.copy()
        self._finish_load.clear()
        return finished_loading
tests/v1/kv_connector/kv_load_exception_handling/test.sh
deleted
100644 → 0
View file @
9ba1c88a
#!/bin/bash
# Launch a vLLM server that uses the RandomDropConnector defined in
# random_drop_connector.py (located next to this script).

# Abort on the first failing command or unset variable instead of starting
# the server with a broken environment.
set -euo pipefail

# Resolve the directory containing this script, following symlinks.
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"

# Make random_drop_connector importable. Only add the ':' separator when
# PYTHONPATH already has content: a leading empty entry would put the
# current working directory on the module search path.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${SCRIPT_DIR}"

vllm serve DeepSeek-V2-Lite-Chat \
    --trust-remote-code \
    --served-model-name vllm_cpu_offload \
    --max-model-len 32768 \
    --no-enable-prefix-caching \
    --max-seq-len-to-capture 10000 \
    --max-num-seqs 64 \
    --gpu-memory-utilization 0.9 \
    --host 0.0.0.0 \
    -tp 2 \
    --kv-transfer-config '{"kv_connector":"RandomDropConnector","kv_role":"kv_both","kv_connector_module_path":"random_drop_connector"}'
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
b18b417f
...
...
@@ -139,27 +139,13 @@ class KVOutputAggregator:
finished_set
.
add
(
req_id
)
del
remaining_count_dict
[
req_id
]
def update_finished_load_dict(worker_finished_loading_dict: dict[str, int],
                              finished_loading_dict: dict[str, int]):
    """Fold one worker's loaded-token counts into the aggregate in place.

    For a request already present in ``finished_loading_dict`` the smaller
    of the stored and the newly reported count is kept; otherwise the
    reported count is inserted. A falsy ``worker_finished_loading_dict``
    (``None`` or empty) leaves the aggregate untouched.
    """
    if not worker_finished_loading_dict:
        return
    for req_id, loaded in worker_finished_loading_dict.items():
        # Defaulting to `loaded` makes min() a plain insert for new IDs.
        previous = finished_loading_dict.get(req_id, loaded)
        finished_loading_dict[req_id] = min(previous, loaded)
finished_sending
=
set
[
str
]()
finished_recving
=
set
[
str
]()
finished_loading_dict
:
dict
[
str
,
int
]
=
{}
for
output
in
outputs
:
update_finished_set
(
output
.
finished_sending
,
self
.
_send_remaining_count
,
finished_sending
)
update_finished_set
(
output
.
finished_recving
,
self
.
_recv_remaining_count
,
finished_recving
)
update_finished_load_dict
(
output
.
finished_loading_dict
,
finished_loading_dict
)
# select output of the worker specified by output_rank
output
=
outputs
[
output_rank
]
...
...
@@ -171,7 +157,7 @@ class KVOutputAggregator:
# send/recv
output
.
finished_sending
=
finished_sending
if
finished_sending
else
None
output
.
finished_recving
=
finished_recving
if
finished_recving
else
None
output
.
finished_loading_dict
=
finished_loading_dict
or
None
return
output
def
async_aggregate
(
self
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
b18b417f
...
...
@@ -28,9 +28,6 @@ The class provides the following primitives:
get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving.
get_finished_loading() - called with scheduler outputs, returns
a dictionary that the keys are request IDs and the values are
the actual number of tokens loaded from the remote KV cache
"""
import
enum
...
...
@@ -222,23 +219,6 @@ class KVConnectorBase_V1(ABC):
"""
return
None
,
None
def get_finished_loading(
        self, scheduler_output: "SchedulerOutput") -> dict[str, int]:
    """
    Report how many tokens were actually loaded from the remote KV cache
    for requests whose loading has completed.

    The base implementation reports nothing; connectors that track
    partial loads are expected to override it.

    Args:
        scheduler_output: Unused by the base implementation.

    Returns:
        A dictionary mapping request IDs to the number of tokens that
        were successfully loaded for that request. Always empty here.
    """
    no_finished_loads: dict[str, int] = {}
    return no_finished_loads
# ==============================
# Scheduler-side methods
# ==============================
...
...
vllm/sequence.py
View file @
b18b417f
...
...
@@ -1167,8 +1167,6 @@ class IntermediateTensors:
# [req_ids]
finished_sending
:
Optional
[
set
[
str
]]
=
None
finished_recving
:
Optional
[
set
[
str
]]
=
None
#req_id -> num_actual_load_tokens
finished_loading_dict
:
Optional
[
dict
[
str
,
int
]]
=
None
def
__init__
(
self
,
tensors
):
# manually define this function, so that
...
...
vllm/v1/core/sched/scheduler.py
View file @
b18b417f
...
...
@@ -118,9 +118,6 @@ class Scheduler(SchedulerInterface):
# KV Connector: requests in process of async KV loading or recving
self
.
finished_recving_kv_req_ids
:
set
[
str
]
=
set
()
# The keys are request IDs, and the values are corresponding token
# count that have been successfully loaded from the remote KV store
self
.
finished_loading_dict
:
dict
[
str
,
int
]
=
{}
# Encoder-related.
# Calculate encoder cache size if applicable
...
...
@@ -1097,27 +1094,6 @@ class Scheduler(SchedulerInterface):
(
block_ids
,
)
=
self
.
kv_cache_manager
.
get_block_ids
(
request
.
request_id
)
return
self
.
connector
.
request_finished
(
request
,
block_ids
)
def
_update_actual_load_token_num_from_remote_kv
(
self
,
request
:
Request
)
->
bool
:
num_actual_load_tokens
=
self
.
finished_loading_dict
.
pop
(
request
.
request_id
)
num_computed_tokens
=
num_actual_load_tokens
assert
self
.
connector
is
not
None
if
num_actual_load_tokens
<=
0
and
hasattr
(
self
.
connector
,
"add_failure_request"
):
self
.
connector
.
add_failure_request
(
request
)
return
True
if
num_actual_load_tokens
==
request
.
num_tokens
:
num_computed_tokens
-=
1
self
.
kv_cache_manager
.
cache_blocks
(
request
,
num_computed_tokens
)
# Update the request state for scheduling.
request
.
num_computed_tokens
=
num_computed_tokens
return
True
def
_update_waiting_for_remote_kv
(
self
,
request
:
Request
)
->
bool
:
"""
KV Connector: check if the request_id is finished_recving.
...
...
@@ -1131,9 +1107,6 @@ class Scheduler(SchedulerInterface):
WAITING_FOR_REMOTE_KV.
"""
assert
self
.
connector
is
not
None
if
request
.
request_id
in
self
.
finished_loading_dict
:
return
self
.
_update_actual_load_token_num_from_remote_kv
(
request
)
if
request
.
request_id
not
in
self
.
finished_recving_kv_req_ids
:
return
False
...
...
@@ -1172,6 +1145,3 @@ class Scheduler(SchedulerInterface):
for
req_id
in
(
model_runner_output
.
finished_sending
or
()):
logger
.
debug
(
"Finished sending KV transfer for request %s"
,
req_id
)
self
.
_free_blocks
(
self
.
requests
[
req_id
])
if
model_runner_output
.
finished_loading_dict
:
self
.
finished_loading_dict
.
update
(
model_runner_output
.
finished_loading_dict
)
vllm/v1/outputs.py
View file @
b18b417f
...
...
@@ -107,8 +107,6 @@ class ModelRunnerOutput:
# [req_ids]
finished_sending
:
Optional
[
set
[
str
]]
=
None
finished_recving
:
Optional
[
set
[
str
]]
=
None
# req_id -> actual_load_token from connector
finished_loading_dict
:
Optional
[
dict
[
str
,
int
]]
=
None
# req_id -> num_nans_in_logits
num_nans_in_logits
:
Optional
[
dict
[
str
,
int
]]
=
None
...
...
@@ -123,5 +121,4 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
pooler_output
=
[],
finished_sending
=
None
,
finished_recving
=
None
,
finished_loading_dict
=
None
,
num_nans_in_logits
=
None
)
vllm/v1/worker/gpu_model_runner.py
View file @
b18b417f
...
...
@@ -1375,7 +1375,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_scheduled_tokens_np
:
np
.
ndarray
,
finished_sending
:
Optional
[
set
[
str
]],
finished_recving
:
Optional
[
set
[
str
]],
finished_loading_dict
:
Optional
[
dict
[
str
,
int
]],
)
->
ModelRunnerOutput
:
assert
self
.
input_batch
.
num_reqs
==
\
len
(
self
.
input_batch
.
pooling_params
),
\
...
...
@@ -1412,7 +1411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pooler_output
=
pooler_output
,
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
finished_loading_dict
=
finished_loading_dict
,
)
@
torch
.
inference_mode
()
...
...
@@ -1532,7 +1530,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
self
.
get_finished_kv_transfers
(
scheduler_output
))
finished_loading_dict
=
self
.
get_finished_loading
(
scheduler_output
)
if
self
.
use_aux_hidden_state_outputs
:
hidden_states
,
aux_hidden_states
=
model_output
...
...
@@ -1550,11 +1547,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
not
get_pp_group
().
is_last_rank
:
# For mid-pipeline stages, return the hidden states.
if
not
broadcast_pp_output
:
if
(
finished_sending
or
finished_recving
or
finished_loading_dict
):
if
finished_sending
or
finished_recving
:
hidden_states
.
finished_sending
=
finished_sending
hidden_states
.
finished_recving
=
finished_recving
hidden_states
.
finished_loading_dict
=
finished_loading_dict
return
hidden_states
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
get_pp_group
().
send_tensor_dict
(
hidden_states
.
tensors
,
...
...
@@ -1564,7 +1559,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
self
.
input_batch
.
pooling_params
:
return
self
.
_pool
(
hidden_states
,
num_scheduled_tokens
,
num_scheduled_tokens_np
,
finished_sending
,
finished_recving
,
finished_loading_dict
)
finished_recving
)
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
...
...
@@ -1716,7 +1711,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pooler_output
=
[],
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
finished_loading_dict
=
finished_loading_dict
,
num_nans_in_logits
=
num_nans_in_logits
,
)
...
...
vllm/v1/worker/gpu_worker.py
View file @
b18b417f
...
...
@@ -359,12 +359,10 @@ class Worker(WorkerBase):
# In case of PP with kv transfer, we need to pass through the
# finished_sending and finished_recving buffers.
new_output
=
EMPTY_MODEL_RUNNER_OUTPUT
if
(
output
.
finished_sending
or
output
.
finished_recving
or
output
.
finished_loading_dict
):
if
output
.
finished_sending
or
output
.
finished_recving
:
new_output
=
copy
.
copy
(
new_output
)
new_output
.
finished_sending
=
output
.
finished_sending
new_output
.
finished_recving
=
output
.
finished_recving
new_output
.
finished_loading_dict
=
output
.
finished_loading_dict
output
=
new_output
assert
isinstance
(
output
,
ModelRunnerOutput
)
...
...
vllm/v1/worker/kv_connector_model_runner_mixin.py
View file @
b18b417f
...
...
@@ -53,14 +53,6 @@ class KVConnectorModelRunnerMixin:
scheduler_output
.
finished_req_ids
)
return
None
,
None
@staticmethod
def get_finished_loading(
        scheduler_output: "SchedulerOutput") -> dict[str, int]:
    """Ask the active KV connector for its finished-loading report.

    Returns an empty dict when no KV transfer group exists; otherwise
    forwards ``scheduler_output`` to the group's ``get_finished_loading``
    and returns its result unchanged.
    """
    if not has_kv_transfer_group():
        return {}
    connector = get_kv_transfer_group()
    return connector.get_finished_loading(scheduler_output)
def
kv_connector_no_forward
(
self
,
scheduler_output
:
"SchedulerOutput"
,
vllm_config
:
VllmConfig
)
->
ModelRunnerOutput
:
# KV send/recv even if no work to do.
...
...
@@ -68,14 +60,11 @@ class KVConnectorModelRunnerMixin:
self
.
maybe_setup_kv_connector
(
scheduler_output
)
finished_sending
,
finished_recving
=
(
self
.
get_finished_kv_transfers
(
scheduler_output
))
finished_loading_dict
=
self
.
get_finished_loading
(
scheduler_output
)
if
(
not
finished_sending
and
not
finished_recving
and
not
finished_loading_dict
):
if
not
finished_sending
and
not
finished_recving
:
return
EMPTY_MODEL_RUNNER_OUTPUT
output
=
copy
.
copy
(
EMPTY_MODEL_RUNNER_OUTPUT
)
output
.
finished_sending
=
finished_sending
output
.
finished_recving
=
finished_recving
output
.
finished_loading_dict
=
finished_loading_dict
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment