Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
16fb668b
Unverified
Commit
16fb668b
authored
Aug 11, 2025
by
GuanLuo
Committed by
GitHub
Aug 11, 2025
Browse files
fix: NIXL connector transfers partial block to pass full multi-modal context (#21074)
Signed-off-by:
GuanLuo
<
gluo@nvidia.com
>
parent
f7dcce7a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
130 additions
and
41 deletions
+130
-41
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+8
-10
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+14
-9
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+99
-5
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+9
-17
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
16fb668b
...
@@ -173,9 +173,9 @@ def test_prompt_less_than_block_size():
...
@@ -173,9 +173,9 @@ def test_prompt_less_than_block_size():
"""
"""
Test that we can handle case where prompt is < block.
Test that we can handle case where prompt is < block.
In this case, the P worker will s
end empty
remote_block_ids
.
In this case, the P worker will s
till send
remote_block_ids
of the
The D worker should
not
schedule an async read
in this case,
partial block.
The D worker should schedule an async read
s
in
ce
th
ere is nothing to pull
.
in th
is case
.
"""
"""
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
scheduler
=
create_scheduler
(
vllm_config
)
scheduler
=
create_scheduler
(
vllm_config
)
...
@@ -184,22 +184,20 @@ def test_prompt_less_than_block_size():
...
@@ -184,22 +184,20 @@ def test_prompt_less_than_block_size():
BLOCK_SIZE
=
vllm_config
.
cache_config
.
block_size
BLOCK_SIZE
=
vllm_config
.
cache_config
.
block_size
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
0.5
)
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
0.5
)
# Request will have
0
remote block
s
.
# Request will have
1 partial
remote block.
request
=
create_request
(
request_id
=
1
,
request
=
create_request
(
request_id
=
1
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
do_remote_prefill
=
True
,
do_remote_prefill
=
True
,
num_remote_blocks
=
0
)
num_remote_blocks
=
1
)
scheduler
.
add_request
(
request
)
scheduler
.
add_request
(
request
)
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
# This request
should not have to
read async.
# This request
will
read async.
kv_connector_metadata
=
scheduler_output
.
kv_connector_metadata
kv_connector_metadata
=
scheduler_output
.
kv_connector_metadata
assert
kv_connector_metadata
is
not
None
assert
kv_connector_metadata
is
not
None
assert
isinstance
(
kv_connector_metadata
,
NixlConnectorMetadata
)
assert
isinstance
(
kv_connector_metadata
,
NixlConnectorMetadata
)
assert
len
(
kv_connector_metadata
.
reqs_to_recv
)
==
0
assert
len
(
kv_connector_metadata
.
reqs_to_recv
)
==
1
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
# This request should be scheduled regularly.
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
1
class
FakeNixlConnectorWorker
(
NixlConnectorWorker
):
class
FakeNixlConnectorWorker
(
NixlConnectorWorker
):
...
...
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
View file @
16fb668b
...
@@ -121,13 +121,18 @@ def test_short_prompt_lifecycle():
...
@@ -121,13 +121,18 @@ def test_short_prompt_lifecycle():
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request
])
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request
])
# (1c): update_from_output()
# (1c): update_from_output()
# Since tokens < block_size, there will be no kv xfer.
# Even though tokens < block_size, there will be kv xfer for partial block.
# So this should be cleaned up immediately.
eco
=
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
_
=
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
kv_transfer_params
=
eco
[
0
].
outputs
[
0
].
kv_transfer_params
assert
(
len
(
kv_transfer_params
[
"remote_block_ids"
])
==
1
)
# Confirm we do not have any memory leaks after req lifecycle.
# Confirm we do not have any memory leaks after req lifecycle.
# We need one more call to schedule() to clear data for persistent batch.
# We need to mark sending finish to clear data for persistent batch.
_
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
copy
.
deepcopy
(
EMPTY_MODEL_RUNNER_OUTPUT
)
model_runner_output
.
finished_sending
=
[
request
.
request_id
]
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
assert_scheduler_empty
(
scheduler
)
assert_scheduler_empty
(
scheduler
)
...
@@ -169,16 +174,16 @@ def test_prefix_cache_lifecycle():
...
@@ -169,16 +174,16 @@ def test_prefix_cache_lifecycle():
eco
=
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
eco
=
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
kv_transfer_params
=
eco
[
0
].
outputs
[
0
].
kv_transfer_params
kv_transfer_params
=
eco
[
0
].
outputs
[
0
].
kv_transfer_params
# Ensure we send all block ids, even if there is a cache hit.
# Ensure we send all block ids, including the partial blocks,
# even if there is a cache hit.
assert
(
len
(
assert
(
len
(
kv_transfer_params
[
"remote_block_ids"
])
==
NUM_EXTERNAL_FULL_BLOCKS
)
kv_transfer_params
[
"remote_block_ids"
])
==
(
NUM_EXTERNAL_FULL_BLOCKS
+
1
))
# STEP (2): Ensure it is freed.
# STEP (2): Ensure it is freed.
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
scheduler
.
schedule
()
model_runner_output
=
copy
.
deepcopy
(
EMPTY_MODEL_RUNNER_OUTPUT
)
model_runner_output
=
copy
.
deepcopy
(
EMPTY_MODEL_RUNNER_OUTPUT
)
model_runner_output
.
kv_connector_output
=
KVConnectorOutput
(
model_runner_output
.
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
[
request_remote
.
request_id
])
finished_sending
=
[
request_remote
.
request_id
])
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
_
=
scheduler
.
schedule
()
assert_scheduler_empty
(
scheduler
)
assert_scheduler_empty
(
scheduler
)
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
View file @
16fb668b
...
@@ -362,7 +362,7 @@ def test_cannot_schedule_after_recv():
...
@@ -362,7 +362,7 @@ def test_cannot_schedule_after_recv():
BLOCK_SIZE
=
vllm_config
.
cache_config
.
block_size
BLOCK_SIZE
=
vllm_config
.
cache_config
.
block_size
# Prompt will use 2 blocks + 1 block after we schedule.
# Prompt will use 2 blocks + 1 block after we schedule.
NUM_TOKENS_LOCAL
=
int
(
BLOCK_SIZE
*
NUM_PROMPT_BLOCKS
)
NUM_TOKENS_LOCAL
=
int
(
BLOCK_SIZE
*
NUM_PROMPT_BLOCKS
)
NUM_TOKENS_REMOTE
=
int
(
BLOCK_SIZE
*
(
NUM_PROMPT_BLOCKS
+
0.5
)
)
NUM_TOKENS_REMOTE
=
int
(
BLOCK_SIZE
*
NUM_PROMPT_BLOCKS
)
request_normal
=
create_request
(
request_id
=
1
,
num_tokens
=
NUM_TOKENS_LOCAL
)
request_normal
=
create_request
(
request_id
=
1
,
num_tokens
=
NUM_TOKENS_LOCAL
)
request_remote
=
create_request
(
request_id
=
2
,
request_remote
=
create_request
(
request_id
=
2
,
...
@@ -393,14 +393,24 @@ def test_cannot_schedule_after_recv():
...
@@ -393,14 +393,24 @@ def test_cannot_schedule_after_recv():
assert
len
(
scheduler
.
running
)
==
1
assert
len
(
scheduler
.
running
)
==
1
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler
.
waiting
)
==
1
# Step 4: try to schedule, not enough blocks.
# Step 4: try to schedule, remote request is put to running list
# because the transfer is completed.
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_normal
,
request_remote
])
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
waiting
)
==
0
# Step 5: Remote request will be put back to waiting list
# because it needs new block to hold generated token.
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_normal
])
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_normal
])
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
assert
len
(
scheduler
.
running
)
==
1
assert
len
(
scheduler
.
running
)
==
1
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler
.
waiting
)
==
1
# Step
5
: finish the request, free it.
# Step
6
: finish the request, free it.
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_normal
],
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_normal
],
use_eos
=
True
)
use_eos
=
True
)
...
@@ -408,15 +418,99 @@ def test_cannot_schedule_after_recv():
...
@@ -408,15 +418,99 @@ def test_cannot_schedule_after_recv():
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler
.
waiting
)
==
1
# Step 6: now we can schedule (with 2 blocks computed).
# Step 7: now we can schedule (with 2 blocks computed),
# request is retrieved from preempted list.
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_remote
])
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_remote
])
assert
(
scheduler_output
.
scheduled_
new
_reqs
[
0
]
.
num_computed_tokens
==
assert
(
scheduler_output
.
scheduled_
cached
_reqs
.
num_computed_tokens
[
0
]
==
NUM_PROMPT_BLOCKS
*
BLOCK_SIZE
)
NUM_PROMPT_BLOCKS
*
BLOCK_SIZE
)
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
assert
len
(
scheduler
.
running
)
==
1
assert
len
(
scheduler
.
running
)
==
1
assert
len
(
scheduler
.
waiting
)
==
0
assert
len
(
scheduler
.
waiting
)
==
0
# Step 8: free everything.
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_remote
],
use_eos
=
True
)
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
_
=
scheduler
.
schedule
()
assert_scheduler_empty
(
scheduler
)
def
test_cannot_recv
():
"""
Test that we can handle no schedule KV block transfer due to not
enough remaining KV blocks.
"""
# NOTE: the KVCacheManager will use 1 null block.
# So there are 5 total working blocks.
TOTAL_NUM_BLOCKS
=
6
vllm_config
=
create_vllm_config
()
scheduler
=
create_scheduler
(
vllm_config
,
num_blocks
=
TOTAL_NUM_BLOCKS
)
# Prime the KVCache.
NUM_PROMPT_BLOCKS
=
2
BLOCK_SIZE
=
vllm_config
.
cache_config
.
block_size
# Prompt will use 2 blocks + 1 block after we schedule.
NUM_TOKENS_LOCAL
=
int
(
BLOCK_SIZE
*
NUM_PROMPT_BLOCKS
)
NUM_TOKENS_REMOTE
=
int
(
BLOCK_SIZE
*
(
NUM_PROMPT_BLOCKS
+
0.5
))
request_normal
=
create_request
(
request_id
=
1
,
num_tokens
=
NUM_TOKENS_LOCAL
)
request_remote
=
create_request
(
request_id
=
2
,
num_tokens
=
NUM_TOKENS_REMOTE
,
do_remote_prefill
=
True
)
# STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
scheduler
.
add_request
(
request_normal
)
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_normal
])
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
assert
len
(
scheduler
.
running
)
==
1
assert
len
(
scheduler
.
waiting
)
==
0
# Step 2: 3 blocks are in use,
# need 3 new for remote blocks but only 2 are available.
scheduler
.
add_request
(
request_remote
)
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_normal
])
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
assert
len
(
scheduler
.
running
)
==
1
assert
len
(
scheduler
.
waiting
)
==
1
# Should not have KV transfer in progress.
assert
(
request_remote
.
status
!=
RequestStatus
.
WAITING_FOR_REMOTE_KVS
)
# Step 3: finish the request, free it.
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_normal
],
use_eos
=
True
)
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler
.
waiting
)
==
1
# Step 4: now we can initiate KV transfer (with 2 blocks computed).
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[])
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler
.
waiting
)
==
1
assert
(
request_remote
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
)
# Step 5: finish recving (5 blocks in use)
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[],
finished_recving
=
[
request_remote
.
request_id
])
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler
.
waiting
)
==
1
# Step 6: schedule remote request
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_remote
])
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
assert
len
(
scheduler
.
running
)
==
1
assert
len
(
scheduler
.
waiting
)
==
0
# Step 7: free everything.
# Step 7: free everything.
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_remote
],
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_remote
],
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
16fb668b
...
@@ -29,7 +29,7 @@ from vllm.distributed.utils import divide
...
@@ -29,7 +29,7 @@ from vllm.distributed.utils import divide
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
make_zmq_path
,
make_zmq_socket
,
round_down
from
vllm.utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.request
import
RequestStatus
from
vllm.v1.request
import
RequestStatus
...
@@ -275,10 +275,7 @@ class NixlConnectorScheduler:
...
@@ -275,10 +275,7 @@ class NixlConnectorScheduler:
if
params
is
not
None
and
params
.
get
(
"do_remote_prefill"
):
if
params
is
not
None
and
params
.
get
(
"do_remote_prefill"
):
# Remote prefill: get all prompt blocks from remote.
# Remote prefill: get all prompt blocks from remote.
assert
num_computed_tokens
%
self
.
block_size
==
0
count
=
len
(
request
.
prompt_token_ids
)
-
num_computed_tokens
rounded_num_prompt_tokens
=
round_down
(
len
(
request
.
prompt_token_ids
),
self
.
block_size
)
count
=
max
(
rounded_num_prompt_tokens
-
num_computed_tokens
,
0
)
if
count
>
0
:
if
count
>
0
:
return
count
,
True
return
count
,
True
...
@@ -301,18 +298,16 @@ class NixlConnectorScheduler:
...
@@ -301,18 +298,16 @@ class NixlConnectorScheduler:
# NOTE: when accelerator is not directly supported by Nixl,
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
# prefilled blocks need to be saved to host memory before transfer.
#
figure out full computed blocks to save
#
save all blocks
block_ids
=
blocks
.
get_block_ids
()[
0
]
block_ids
=
blocks
.
get_block_ids
()[
0
]
all_full
=
request
.
num_tokens
%
self
.
block_size
==
0
full_block_ids
=
(
block_ids
if
all_full
else
block_ids
[:
-
1
])
# TODO: skip the blocks that are already in the host xfer buffer.
# TODO: skip the blocks that are already in the host xfer buffer.
# Currently, the host xfer buffer block is 1-to-1 mapped to device
# Currently, the host xfer buffer block is 1-to-1 mapped to device
# kv blocks, so host blocks won't be flushed as long as its device
# kv blocks, so host blocks won't be flushed as long as its device
# block is not overwritten; and it will be safe to skip saving them
# block is not overwritten; and it will be safe to skip saving them
# to host xfer buffer.
# to host xfer buffer.
if
full_
block_ids
:
if
block_ids
:
self
.
_reqs_need_save
[
request
.
request_id
]
=
\
self
.
_reqs_need_save
[
request
.
request_id
]
=
\
(
request
,
full_
block_ids
)
(
request
,
block_ids
)
elif
params
.
get
(
"do_remote_prefill"
):
elif
params
.
get
(
"do_remote_prefill"
):
if
params
.
get
(
"remote_block_ids"
):
if
params
.
get
(
"remote_block_ids"
):
if
all
(
p
in
params
for
p
in
(
"remote_engine_id"
,
"remote_host"
,
if
all
(
p
in
params
for
p
in
(
"remote_engine_id"
,
"remote_host"
,
...
@@ -401,12 +396,9 @@ class NixlConnectorScheduler:
...
@@ -401,12 +396,9 @@ class NixlConnectorScheduler:
or
request
.
status
!=
RequestStatus
.
FINISHED_LENGTH_CAPPED
):
or
request
.
status
!=
RequestStatus
.
FINISHED_LENGTH_CAPPED
):
return
False
,
None
return
False
,
None
# Get computed blocks.
# TODO: check whether block_ids actually ever be 0. If not we could
all_full
=
request
.
num_computed_tokens
%
self
.
block_size
==
0
# remove the conditional below
computed_block_ids
=
block_ids
if
all_full
else
block_ids
[:
-
1
]
delay_free_blocks
=
len
(
block_ids
)
>
0
# If prompt < block_size, no xfer so free blocks immediately.
delay_free_blocks
=
len
(
computed_block_ids
)
>
0
if
delay_free_blocks
:
if
delay_free_blocks
:
# Prefill request on remote. It will be read from D upon completion
# Prefill request on remote. It will be read from D upon completion
...
@@ -416,7 +408,7 @@ class NixlConnectorScheduler:
...
@@ -416,7 +408,7 @@ class NixlConnectorScheduler:
return
delay_free_blocks
,
dict
(
return
delay_free_blocks
,
dict
(
do_remote_prefill
=
True
,
do_remote_prefill
=
True
,
do_remote_decode
=
False
,
do_remote_decode
=
False
,
remote_block_ids
=
computed_
block_ids
,
remote_block_ids
=
block_ids
,
remote_engine_id
=
self
.
engine_id
,
remote_engine_id
=
self
.
engine_id
,
remote_host
=
self
.
side_channel_host
,
remote_host
=
self
.
side_channel_host
,
remote_port
=
self
.
side_channel_port
,
remote_port
=
self
.
side_channel_port
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment