Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5f1ebb6f
Commit
5f1ebb6f
authored
Apr 23, 2026
by
maxiang
Browse files
[FIX] 解决new_block_ids 为None造成OOM
parent
aef3c487
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
97 additions
and
38 deletions
+97
-38
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector.py
...uted/kv_transfer/kv_connector/v1/du/du_swift_connector.py
+23
-9
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+26
-11
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+48
-18
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector.py
View file @
5f1ebb6f
...
...
@@ -48,6 +48,7 @@ class ReqMeta:
block_ids_tensor
=
torch
.
tensor
(
block_ids
)
num_blocks
=
block_ids_tensor
.
shape
[
0
]
block_offsets
=
torch
.
arange
(
0
,
block_size
)
slot_mapping
=
block_offsets
.
reshape
((
1
,
block_size
))
+
\
block_ids_tensor
.
reshape
((
num_blocks
,
1
))
*
block_size
slot_mapping
=
slot_mapping
.
flatten
()[:
valid_num_tokens
]
...
...
@@ -70,7 +71,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata):
self
,
request_id
:
str
,
token_ids
:
list
[
int
],
block_ids
:
list
[
int
],
block_ids
:
list
[
int
],
#这里为None ??
block_size
:
int
,
)
->
None
:
self
.
requests
.
append
(
...
...
@@ -619,12 +620,15 @@ class DuSwiftConnector(KVConnectorBase_V1):
num_scheduled_tokens
=
(
scheduler_output
.
num_scheduled_tokens
)[
req_id
]
num_tokens
=
(
num_scheduled_tokens
+
num_computed_tokens
)
# assert req_id in self.chunked_prefill
if
req_id
not
in
self
.
chunked_prefill
:
continue
block_ids
=
new_block_ids
[
0
]
delta_block_ids
=
(
[]
if
new_block_ids
is
None
else
new_block_ids
[
0
])
if
not
resumed_from_preemption
:
block_ids
=
(
self
.
chunked_prefill
[
req_id
][
0
]
+
block_ids
)
block_ids
=
(
self
.
chunked_prefill
[
req_id
][
0
]
+
delta_block_ids
)
else
:
block_ids
=
delta_block_ids
prompt_token_ids
=
self
.
chunked_prefill
[
req_id
][
1
]
# the request's prompt is chunked prefill again
if
num_tokens
<
len
(
prompt_token_ids
):
...
...
@@ -644,13 +648,23 @@ class DuSwiftConnector(KVConnectorBase_V1):
if
not
resumed_from_preemption
:
break
if
req_id
in
self
.
_requests_need_load
:
request
,
_
=
self
.
_requests_need_load
.
pop
(
req_id
)
request
,
fallback_block_ids
=
(
self
.
_requests_need_load
.
pop
(
req_id
))
total_tokens
=
num_computed_tokens
+
1
token_ids
=
request
.
all_token_ids
[:
total_tokens
]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
if
new_block_ids
is
not
None
:
block_ids
=
new_block_ids
[
0
]
elif
fallback_block_ids
:
block_ids
=
fallback_block_ids
logger
.
warning
(
"Using fallback block_ids for resumed request "
"%s: new_block_ids is None."
,
req_id
)
else
:
logger
.
warning
(
"Skip KV load meta for resumed request %s: "
"no block_ids available."
,
req_id
)
continue
meta
.
add_request
(
request_id
=
req_id
,
token_ids
=
token_ids
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
5f1ebb6f
...
...
@@ -477,10 +477,11 @@ class P2pNcclConnector(KVConnectorBase_V1):
"""
Update KVConnector state after block allocation.
"""
#将全量blocks存入字典
if
not
self
.
is_producer
and
num_external_tokens
>
0
:
self
.
_requests_need_load
[
request
.
request_id
]
=
(
request
,
blocks
.
get_block_ids
()[
0
],
blocks
.
get_block_ids
()[
0
],
#转换为block ID 列表 req的全量blocks
)
def
build_connector_meta
(
...
...
@@ -520,6 +521,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
block_size
=
self
.
_block_size
,
)
continue
#新请求
if
new_req
.
req_id
in
self
.
_requests_need_load
:
meta
.
add_request
(
request_id
=
new_req
.
req_id
,
...
...
@@ -538,16 +540,19 @@ class P2pNcclConnector(KVConnectorBase_V1):
if
self
.
is_producer
:
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_tokens
=
num_scheduled_tokens
+
num_computed_tokens
# assert req_id in self.chunked_prefill
if
req_id
not
in
self
.
chunked_prefill
:
continue
assert
new
_block_ids
is
not
None
block_ids
=
new_block_ids
[
0
]
delta
_block_ids
=
(
[]
if
new_block_ids
is
None
else
new_block_ids
[
0
]
)
if
not
resumed_from_preemption
:
block_ids
=
self
.
chunked_prefill
[
req_id
][
0
]
+
block_ids
block_ids
=
(
self
.
chunked_prefill
[
req_id
][
0
]
+
delta_block_ids
)
else
:
block_ids
=
delta_block_ids
prompt_token_ids
=
self
.
chunked_prefill
[
req_id
][
1
]
assert
prompt_token_ids
is
not
None
# the request's prompt is chunked prefill again
# ???? 一直累积
if
num_tokens
<
len
(
prompt_token_ids
):
self
.
chunked_prefill
[
req_id
]
=
(
block_ids
,
prompt_token_ids
)
continue
...
...
@@ -563,17 +568,27 @@ class P2pNcclConnector(KVConnectorBase_V1):
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
if
not
resumed_from_preemption
:
break
if
req_id
in
self
.
_requests_need_load
:
request
,
_
=
self
.
_requests_need_load
.
pop
(
req_id
)
request
,
fallback_block_ids
=
(
self
.
_requests_need_load
.
pop
(
req_id
))
total_tokens
=
num_computed_tokens
+
1
token_ids
=
request
.
all_token_ids
[:
total_tokens
]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
assert
new_block_ids
is
not
None
if
new_block_ids
is
not
None
:
block_ids
=
new_block_ids
[
0
]
elif
fallback_block_ids
:
block_ids
=
fallback_block_ids
logger
.
warning
(
"Using fallback block_ids for resumed request "
"%s: new_block_ids is None."
,
req_id
)
else
:
logger
.
warning
(
"Skip KV load meta for resumed request %s: "
"no block_ids available."
,
req_id
)
continue
meta
.
add_request
(
request_id
=
req_id
,
...
...
vllm/v1/core/sched/scheduler.py
View file @
5f1ebb6f
...
...
@@ -315,6 +315,17 @@ class Scheduler(SchedulerInterface):
pass
return
num_new_tokens
def
_kv_connector_lookahead_for_waiting
(
self
,
request
:
Request
)
->
int
:
if
self
.
connector
is
not
None
and
self
.
connector
.
is_producer
:
return
0
return
0
if
request
.
num_computed_tokens
==
0
else
self
.
num_lookahead_tokens
def
_kv_connector_lookahead_for_running
(
self
)
->
int
:
if
self
.
connector
is
not
None
and
self
.
connector
.
is_producer
:
return
0
return
self
.
num_lookahead_tokens
def
schedule_default
(
self
)
->
SchedulerOutput
:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
...
...
@@ -442,7 +453,7 @@ class Scheduler(SchedulerInterface):
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
num_new_tokens
,
num_lookahead_tokens
=
self
.
num_lookahead_tokens
,
num_lookahead_tokens
=
self
.
_kv_connector_lookahead_for_running
()
,
)
if
new_blocks
is
not
None
:
...
...
@@ -480,6 +491,7 @@ class Scheduler(SchedulerInterface):
encoder_compute_budget
+=
num_embeds_to_restore
req_index
-=
1
else
:
#从运行队列中弹出, 强占
preempted_req
=
self
.
running
.
pop
()
self
.
_preempt_request
(
preempted_req
,
scheduled_timestamp
)
...
...
@@ -564,6 +576,7 @@ class Scheduler(SchedulerInterface):
continue
# KVTransfer: skip request if still waiting for remote kvs.
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
#检查kv cache是否已经传输完毕
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
if
is_ready
:
if
request
.
num_preemptions
:
...
...
@@ -577,6 +590,7 @@ class Scheduler(SchedulerInterface):
"%s is still in WAITING_FOR_REMOTE_KVS state."
,
request
.
request_id
,
)
#如果依然没有传输完毕, 将request设置为可跳过的
self
.
waiting
.
pop_request
()
skipped_waiting_requests
.
prepend_request
(
request
)
continue
...
...
@@ -667,7 +681,9 @@ class Scheduler(SchedulerInterface):
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
#总prompt tokens - 已计算tokens
num_new_tokens
=
request
.
num_tokens
-
num_computed_tokens
#长prefill token阈值 max-num-batched-tokens
threshold
=
self
.
scheduler_config
.
long_prefill_token_threshold
if
0
<
threshold
<
num_new_tokens
:
num_new_tokens
=
threshold
...
...
@@ -702,7 +718,7 @@ class Scheduler(SchedulerInterface):
if
num_new_tokens
==
0
:
# The request cannot be scheduled.
break
#区分FA 与 mamba , block size 意义不同
if
self
.
need_mamba_block_aligned_split
:
num_new_tokens
=
self
.
_mamba_block_aligned_split
(
request
,
...
...
@@ -718,16 +734,16 @@ class Scheduler(SchedulerInterface):
# extra block gets allocated which
# creates a mismatch between the number
# of local and remote blocks.
effective_lookahead_tokens
=
(
0
if
request
.
num_computed_tokens
==
0
else
self
.
num_lookahead_tokens
effective_lookahead_tokens
=
self
.
_kv_connector_lookahead_for_waiting
(
request
)
num_encoder_tokens
=
(
self
.
_num_encoder_max_input_tokens
if
self
.
is_encoder_decoder
and
request
.
has_encoder_inputs
else
0
)
#分配新的blocks, 这里只有memory pool已经没有free block才返回None,
# 如果是不需要分配新block, new_blocks是空列表
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
num_new_tokens
,
...
...
@@ -755,6 +771,7 @@ class Scheduler(SchedulerInterface):
if
self
.
connector
is
not
None
:
self
.
connector
.
update_state_after_alloc
(
request
,
#获取这个请求的全量 blocks
self
.
kv_cache_manager
.
get_blocks
(
request
.
request_id
),
num_external_computed_tokens
,
)
...
...
@@ -779,12 +796,14 @@ class Scheduler(SchedulerInterface):
if
request
.
status
==
RequestStatus
.
WAITING
:
scheduled_new_reqs
.
append
(
request
)
elif
request
.
status
==
RequestStatus
.
PREEMPTED
:
#恢复的请求
scheduled_resumed_reqs
.
append
(
request
)
else
:
raise
RuntimeError
(
f
"Invalid request status:
{
request
.
status
}
"
)
if
self
.
lora_config
and
request
.
lora_request
:
scheduled_loras
.
add
(
request
.
lora_request
.
lora_int_id
)
#这里会将恢复或者新请求 重新分配 block
req_to_new_blocks
[
request
.
request_id
]
=
(
self
.
kv_cache_manager
.
get_blocks
(
request
.
request_id
)
)
...
...
@@ -951,6 +970,7 @@ class Scheduler(SchedulerInterface):
request
=
self
.
waiting
.
peek_request
()
# KVTransfer: skip request if still waiting for remote kvs.
#如果该请求为等待远端KV CACHE
if
request
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
:
is_ready
=
self
.
_update_waiting_for_remote_kv
(
request
)
if
is_ready
:
...
...
@@ -1101,13 +1121,8 @@ class Scheduler(SchedulerInterface):
if
num_new_tokens
==
0
:
break
# Handles an edge case when P/D Disaggregation
# is used with Spec Decoding where an
# extra block gets allocated which
# creates a mismatch between the number
# of local and remote blocks.
effective_lookahead_tokens
=
(
0
if
request
.
num_computed_tokens
==
0
else
self
.
num_lookahead_tokens
effective_lookahead_tokens
=
self
.
_kv_connector_lookahead_for_waiting
(
request
)
num_encoder_tokens
=
(
...
...
@@ -1160,6 +1175,7 @@ class Scheduler(SchedulerInterface):
self
.
_update_connector_prefix_cache_stats
(
request
)
self
.
running
.
append
(
request
)
if
self
.
log_stats
:
request
.
record_event
(
EngineCoreEventType
.
SCHEDULED
,
scheduled_timestamp
...
...
@@ -1173,9 +1189,11 @@ class Scheduler(SchedulerInterface):
if
self
.
lora_config
and
request
.
lora_request
:
scheduled_loras
.
add
(
request
.
lora_request
.
lora_int_id
)
req_to_new_blocks
[
request
.
request_id
]
=
(
self
.
kv_cache_manager
.
get_blocks
(
request
.
request_id
)
)
num_scheduled_tokens
[
request
.
request_id
]
=
num_new_tokens
token_budget
-=
num_new_tokens
request
.
status
=
RequestStatus
.
RUNNING
...
...
@@ -1203,6 +1221,7 @@ class Scheduler(SchedulerInterface):
self
.
waiting
.
prepend_requests
(
skipped_waiting_requests
)
# Next, schedule the RUNNING requests.
#只有当本轮没有调度任何 WAITING/RESUMED 请求时,才进入 RUNNING 调度
if
not
scheduled_new_reqs
and
not
scheduled_resumed_reqs
:
req_index
=
0
while
req_index
<
len
(
self
.
running
)
and
token_budget
>
0
:
...
...
@@ -1232,10 +1251,11 @@ class Scheduler(SchedulerInterface):
continue
num_new_tokens
=
(
request
.
num_tokens_with_spec
request
.
num_tokens_with_spec
# 不开MTP = prompt + output + 1 / prompt + output + 2
+
request
.
num_output_placeholders
-
request
.
num_computed_tokens
-
request
.
num_computed_tokens
#prompt token + output - 1
)
if
0
<
self
.
scheduler_config
.
long_prefill_token_threshold
<
num_new_tokens
:
num_new_tokens
=
self
.
scheduler_config
.
long_prefill_token_threshold
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
...
...
@@ -1284,16 +1304,18 @@ class Scheduler(SchedulerInterface):
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
print
(
"xiang ----- new_tokens == 0"
)
req_index
+=
1
continue
# Schedule newly needed KV blocks for the request.
with
record_function_or_nullcontext
(
"schedule: allocate_slots"
):
while
True
:
#给Runing请求分配block
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
num_new_tokens
,
num_lookahead_tokens
=
self
.
num_lookahead_tokens
,
num_lookahead_tokens
=
self
.
_kv_connector_lookahead_for_running
()
,
)
if
new_blocks
is
not
None
:
...
...
@@ -1512,8 +1534,10 @@ class Scheduler(SchedulerInterface):
assert
request
.
status
==
RequestStatus
.
RUNNING
,
(
"Only running requests can be preempted"
)
#释放所有block
self
.
kv_cache_manager
.
free
(
request
)
self
.
encoder_cache_manager
.
free
(
request
)
#请求状态设置为被抢占
request
.
status
=
RequestStatus
.
PREEMPTED
request
.
num_computed_tokens
=
0
request
.
spec_token_ids
.
clear
()
...
...
@@ -1522,6 +1546,7 @@ class Scheduler(SchedulerInterface):
request
.
record_event
(
EngineCoreEventType
.
PREEMPTED
,
timestamp
)
# Put the request back to the waiting queue.
#放入waiting队列队头
self
.
waiting
.
prepend_request
(
request
)
def
_update_after_schedule
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
...
...
@@ -1633,12 +1658,17 @@ class Scheduler(SchedulerInterface):
scheduled_in_prev_step
=
req_id
in
self
.
prev_step_scheduled_req_ids
if
idx
>=
num_running_reqs
:
assert
not
scheduled_in_prev_step
#这里是恢复的请求ID
resumed_req_ids
.
add
(
req_id
)
if
not
scheduled_in_prev_step
:
all_token_ids
[
req_id
]
=
req
.
all_token_ids
.
copy
()
#这里加入新分配的block ids
new_block_ids
.
append
(
req_to_new_blocks
[
req_id
].
get_block_ids
(
allow_none
=
True
)
)
print
(
"new_block_ids : "
,
new_block_ids
)
num_computed_tokens
.
append
(
req
.
num_computed_tokens
)
num_output_tokens
.
append
(
req
.
num_output_tokens
+
req
.
num_output_placeholders
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment