Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
71d1d75b
Unverified
Commit
71d1d75b
authored
Jul 08, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Jul 08, 2025
Browse files
[PD][Nixl] Remote consumer READ timeout for clearing request blocks (#20139)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
72d14d0e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
115 additions
and
10 deletions
+115
-10
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+74
-4
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+32
-5
vllm/envs.py
vllm/envs.py
+9
-1
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
71d1d75b
...
...
@@ -9,10 +9,13 @@ from unittest.mock import patch
import
pytest
from
vllm
import
LLM
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
KVConnectorRole
,
NixlAgentMetadata
,
NixlConnector
,
NixlConnectorMetadata
,
NixlConnectorWorker
)
from
vllm.forward_context
import
ForwardContext
from
vllm.sampling_params
import
SamplingParams
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
...
...
@@ -41,9 +44,9 @@ def test_basic_interface():
assert
kv_connector_metadata
is
not
None
assert
isinstance
(
kv_connector_metadata
,
NixlConnectorMetadata
)
assert
len
(
kv_connector_metadata
.
req
uests
)
==
1
assert
request_id
in
kv_connector_metadata
.
req
uests
req_meta
=
kv_connector_metadata
.
req
uests
[
request_id
]
assert
len
(
kv_connector_metadata
.
req
s_to_recv
)
==
1
assert
request_id
in
kv_connector_metadata
.
req
s_to_recv
req_meta
=
kv_connector_metadata
.
req
s_to_recv
[
request_id
]
for
block_id
,
block
in
zip
(
req_meta
.
local_block_ids
,
scheduler
.
kv_cache_manager
.
coordinator
.
...
...
@@ -78,7 +81,7 @@ def test_prompt_less_than_block_size():
kv_connector_metadata
=
scheduler_output
.
kv_connector_metadata
assert
kv_connector_metadata
is
not
None
assert
isinstance
(
kv_connector_metadata
,
NixlConnectorMetadata
)
assert
len
(
kv_connector_metadata
.
req
uests
)
==
0
assert
len
(
kv_connector_metadata
.
req
s_to_recv
)
==
0
# This request should be scheduled regularly.
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
1
...
...
@@ -371,3 +374,70 @@ class TestNixlHandshake:
if
cnt_finished_reqs
==
total_reqs
:
return
raise
TimeoutError
(
"Took too long to complete async handshake."
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
)
def
test_abort_timeout_on_prefiller
(
monkeypatch
):
"""
Test lifecycle of an aborted Remote Prefill request hitting the timeout.
-----> P
| {process request}
<-\--- | {result is NOT delivered, eg proxy is down}
|
|
| {eventually free blocks}
"""
model_name
=
"Qwen/Qwen3-0.6B"
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"NixlConnector"
,
kv_role
=
"kv_both"
,
)
timeout
=
6
monkeypatch
.
setenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
)
monkeypatch
.
setenv
(
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT"
,
str
(
timeout
))
llm
=
LLM
(
model
=
model_name
,
enforce_eager
=
True
,
gpu_memory_utilization
=
0.5
,
kv_transfer_config
=
kv_transfer_config
,
)
remote_prefill_opts
=
{
"do_remote_decode"
:
True
,
"do_remote_prefill"
:
False
,
"remote_engine_id"
:
None
,
"remote_block_ids"
:
None
,
"remote_host"
:
None
,
"remote_port"
:
None
,
}
# Simulate sidecar request
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
1
,
extra_args
=
{
"kv_transfer_params"
:
remote_prefill_opts
})
scheduler
=
llm
.
llm_engine
.
engine_core
.
engine_core
.
scheduler
req_to_blocks
=
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
req_to_blocks
padding
=
"Just making this request a little longer so that we're sure "
"we're not hitting the small-request lower bound beneath which we don't "
"actually trigger the whole kv transfer, but rather just recompute the "
"blocks on D."
_
=
llm
.
generate
([
f
"What is the capital of Japan?
{
padding
}
"
],
sampling_params
)
# Request finished but not freed
assert
'0'
in
scheduler
.
finished_req_ids
and
'0'
in
req_to_blocks
# Some other request, 0 still not freed
_
=
llm
.
generate
([
f
"What is the capital of Italy?
{
padding
}
"
],
sampling_params
)
assert
'0'
in
req_to_blocks
assert
'1'
in
scheduler
.
finished_req_ids
and
'1'
in
req_to_blocks
# Wait for timeout and trigger another scheduler loop
time
.
sleep
(
timeout
)
_
=
llm
.
generate
([
f
"What is the capital of France?
{
padding
}
"
],
sampling_params
)
# Request-0 times out and is cleared!
assert
'0'
not
in
req_to_blocks
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
71d1d75b
...
...
@@ -79,7 +79,8 @@ class ReqMeta:
class
NixlConnectorMetadata
(
KVConnectorMetadata
):
def
__init__
(
self
):
self
.
requests
:
dict
[
ReqId
,
ReqMeta
]
=
{}
self
.
reqs_to_recv
:
dict
[
ReqId
,
ReqMeta
]
=
{}
self
.
reqs_to_send
:
dict
[
ReqId
,
float
]
=
{}
def
add_new_req
(
self
,
...
...
@@ -87,7 +88,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
local_block_ids
:
list
[
int
],
kv_transfer_params
:
dict
[
str
,
Any
],
):
self
.
req
uests
[
request_id
]
=
ReqMeta
(
self
.
req
s_to_recv
[
request_id
]
=
ReqMeta
(
local_block_ids
=
local_block_ids
,
remote_block_ids
=
kv_transfer_params
[
"remote_block_ids"
],
remote_engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
...
...
@@ -194,10 +195,12 @@ class NixlConnectorScheduler:
vllm_config
.
parallel_config
.
tensor_parallel_size
)
logger
.
info
(
"Initializing NIXL Scheduler %s"
,
engine_id
)
# Requests that need to start recv.
# Requests that need to start recv
/send
.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self
.
_reqs_need_recv
:
dict
[
ReqId
,
tuple
[
Request
,
list
[
int
]]]
=
{}
# Reqs to send and their expiration time
self
.
_reqs_need_send
:
dict
[
ReqId
,
float
]
=
{}
def
get_num_new_matched_tokens
(
self
,
request
:
"Request"
,
...
...
@@ -284,6 +287,9 @@ class NixlConnectorScheduler:
# Clear the list once workers start the transfers
self
.
_reqs_need_recv
.
clear
()
meta
.
reqs_to_send
=
self
.
_reqs_need_send
self
.
_reqs_need_send
=
{}
return
meta
def
request_finished
(
...
...
@@ -325,6 +331,11 @@ class NixlConnectorScheduler:
# If prompt < block_size, no xfer so free blocks immediately.
delay_free_blocks
=
len
(
computed_block_ids
)
>
0
if
delay_free_blocks
:
# Prefill request on remote. It will be read from D upon completion
self
.
_reqs_need_send
[
request
.
request_id
]
=
time
.
perf_counter
(
)
+
envs
.
VLLM_NIXL_ABORT_REQUEST_TIMEOUT
return
delay_free_blocks
,
dict
(
do_remote_prefill
=
True
,
do_remote_decode
=
False
,
...
...
@@ -394,6 +405,8 @@ class NixlConnectorWorker:
# In progress transfers.
# [req_id -> list[handle]]
self
.
_recving_transfers
=
defaultdict
[
ReqId
,
list
[
Transfer
]](
list
)
# Track the expiration time of requests that are waiting to be sent.
self
.
_reqs_to_send
:
dict
[
ReqId
,
float
]
=
{}
# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
...
...
@@ -826,6 +839,16 @@ class NixlConnectorWorker:
"and %s requests done recving"
,
self
.
tp_rank
,
len
(
done_sending
),
len
(
done_recving
))
# Handle timeout to avoid stranding blocks on remote.
now
=
time
.
perf_counter
()
while
self
.
_reqs_to_send
:
req_id
,
expires
=
next
(
iter
(
self
.
_reqs_to_send
.
items
()))
# Sorted dict, oldest requests are put first so we can exit early.
if
now
<
expires
:
break
del
self
.
_reqs_to_send
[
req_id
]
done_sending
.
add
(
req_id
)
if
self
.
world_size
==
1
:
return
done_sending
,
done_recving
...
...
@@ -857,7 +880,7 @@ class NixlConnectorWorker:
all_done_sending
:
set
[
str
]
=
set
()
for
req_id
in
list
(
self
.
_done_sending_count
.
keys
()):
if
self
.
_done_sending_count
[
req_id
]
=
=
self
.
world_size
:
if
self
.
_done_sending_count
[
req_id
]
>
=
self
.
world_size
:
del
self
.
_done_sending_count
[
req_id
]
all_done_sending
.
add
(
req_id
)
...
...
@@ -887,6 +910,7 @@ class NixlConnectorWorker:
tp_ratio
):
notified_req_ids
.
add
(
req_id
)
del
self
.
consumer_notification_counts_by_req
[
req_id
]
del
self
.
_reqs_to_send
[
req_id
]
return
notified_req_ids
def
_pop_done_transfers
(
...
...
@@ -921,7 +945,7 @@ class NixlConnectorWorker:
Start loading by triggering non-blocking nixl_xfer.
We check for these trnxs to complete in each step().
"""
for
req_id
,
meta
in
metadata
.
req
uests
.
items
():
for
req_id
,
meta
in
metadata
.
req
s_to_recv
.
items
():
remote_engine_id
=
meta
.
remote_engine_id
logger
.
debug
(
"start_load_kv for request %s from remote engine %s. "
...
...
@@ -943,6 +967,9 @@ class NixlConnectorWorker:
while
not
self
.
_ready_requests
.
empty
():
self
.
_read_blocks_for_req
(
*
self
.
_ready_requests
.
get_nowait
())
# Add to requests that are waiting to be read and track expiration.
self
.
_reqs_to_send
.
update
(
metadata
.
reqs_to_send
)
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
logger
.
debug
(
"Remote agent %s available, calling _read_blocks for req %s"
,
...
...
vllm/envs.py
View file @
71d1d75b
...
...
@@ -138,6 +138,7 @@ if TYPE_CHECKING:
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION
:
str
=
"NONE"
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16
:
bool
=
True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB
:
Optional
[
int
]
=
None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT
:
int
=
120
def
get_default_cache_root
():
...
...
@@ -953,7 +954,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# generations on machines < 100 for compressed-tensors
# models
"VLLM_USE_NVFP4_CT_EMULATIONS"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_NVFP4_CT_EMULATIONS"
,
"0"
)))
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_NVFP4_CT_EMULATIONS"
,
"0"
))),
# Time (in seconds) after which the KV cache on the producer side is
# automatically cleared if no READ notification is received from the
# consumer. This is only applicable when using NixlConnector in a
# disaggregated decode-prefill setup.
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT"
,
"120"
))
}
# --8<-- [end:env-vars-definition]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment