Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
59dd311c
Unverified
Commit
59dd311c
authored
May 14, 2025
by
Nick Hill
Committed by
GitHub
May 14, 2025
Browse files
[KVConnector] Keep KVTransferParams as a dict (#18033)
parent
d066e520
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
64 additions
and
157 deletions
+64
-157
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+12
-14
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
+2
-2
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+0
-25
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+34
-95
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+8
-4
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+4
-8
vllm/v1/request.py
vllm/v1/request.py
+4
-9
No files found.
tests/v1/kv_connector/unit/utils.py
View file @
59dd311c
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
from
typing
import
Any
,
Optional
import
torch
from
vllm
import
SamplingParams
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
KVTransferConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
)
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlKVTransferParams
)
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
)
...
...
@@ -124,20 +122,20 @@ def create_request(
)
->
Request
:
"""Make dummy request for testing."""
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
if
do_remote_decode
:
assert
not
do_remote_prefill
kv_transfer_params
=
NixlKVTransferParams
(
do_remote_prefill
=
False
,
kv_transfer_params
=
dict
(
do_remote_prefill
=
False
,
do_remote_decode
=
True
)
elif
do_remote_prefill
:
kv_transfer_params
=
NixlKVTransferParams
(
do_remote_prefill
=
True
,
kv_transfer_params
=
dict
(
do_remote_prefill
=
True
,
do_remote_decode
=
False
,
remote_engine_id
=
"my-engine-id"
,
remote_block_ids
=
list
(
range
(
num_remote_blocks
)),
remote_block_ids
=
list
(
range
(
num_remote_blocks
)),
remote_host
=
"my-host"
,
remote_port
=
1234
)
else
:
kv_transfer_params
=
None
max_tokens
=
1
if
do_remote_decode
else
max_tokens
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
View file @
59dd311c
# SPDX-License-Identifier: Apache-2.0
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorRole
,
KVTransferParams
)
KVConnectorBase_V1
,
KVConnectorRole
)
__all__
=
[
"KVConnectorRole"
,
"KVConnectorBase_V1"
,
"KVTransferParams"
]
__all__
=
[
"KVConnectorRole"
,
"KVConnectorBase_V1"
]
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
59dd311c
...
...
@@ -48,23 +48,6 @@ class KVConnectorRole(enum.Enum):
WORKER
=
1
class
KVTransferParams
:
"""
Abstract KVTransferParams used to send KVTransfer
parameters between instances of vLLM.
Specific instances of KVConnector customize this
method for serializing / deserializing msgs sent
via the HTTP protocol.
"""
@
staticmethod
def
from_raw_dict
(
raw_dict
:
Optional
[
dict
[
str
,
Any
]])
->
Optional
[
"KVTransferParams"
]:
return
None
@
dataclass
class
KVConnectorMetadata
:
"""
...
...
@@ -75,7 +58,6 @@ class KVConnectorMetadata:
class
KVConnectorBase_V1
(
ABC
):
_KVTransferParams
=
KVTransferParams
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
logger
.
warning
(
...
...
@@ -213,13 +195,6 @@ class KVConnectorBase_V1(ABC):
# Scheduler-side methods
# ==============================
def
set_kv_transfer_params
(
self
,
request
:
"Request"
):
"""Parse raw KV Transfer params."""
assert
request
.
kv_transfer_params
is
None
kv_transfer_params
=
self
.
_KVTransferParams
.
from_raw_dict
(
request
.
raw_kv_transfer_params
)
request
.
kv_transfer_params
=
kv_transfer_params
@
abstractmethod
def
get_num_new_matched_tokens
(
self
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
59dd311c
...
...
@@ -16,7 +16,7 @@ import zmq
from
vllm
import
envs
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
,
KVTransferParams
)
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tp_group
)
...
...
@@ -44,56 +44,6 @@ except ImportError:
NixlWrapper
=
None
@
dataclass
class
NixlKVTransferParams
(
KVTransferParams
):
def
__init__
(
self
,
do_remote_prefill
:
bool
,
do_remote_decode
:
bool
,
remote_block_ids
:
Optional
[
list
[
int
]]
=
None
,
remote_host
:
Optional
[
str
]
=
None
,
remote_port
:
Optional
[
int
]
=
None
,
remote_engine_id
:
Optional
[
str
]
=
None
,
):
self
.
do_remote_prefill
=
do_remote_prefill
self
.
do_remote_decode
=
do_remote_decode
self
.
remote_block_ids
=
remote_block_ids
self
.
remote_host
=
remote_host
self
.
remote_port
=
remote_port
self
.
remote_engine_id
=
remote_engine_id
@
staticmethod
def
from_raw_dict
(
raw_dict
:
Optional
[
dict
[
str
,
Any
]])
->
Optional
[
"NixlKVTransferParams"
]:
# If no raw transfer params passed, return None.
if
raw_dict
is
None
:
return
None
# Validate the request is formatted properly.
if
((
"do_remote_prefill"
not
in
raw_dict
)
or
(
"do_remote_decode"
not
in
raw_dict
)
or
(
"remote_block_ids"
not
in
raw_dict
)
or
(
"remote_host"
not
in
raw_dict
)
or
(
"remote_port"
not
in
raw_dict
)
or
(
"remote_engine_id"
not
in
raw_dict
)):
logger
.
warning
(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer"
,
raw_dict
)
return
None
return
NixlKVTransferParams
(
do_remote_prefill
=
raw_dict
[
"do_remote_prefill"
],
do_remote_decode
=
raw_dict
[
"do_remote_decode"
],
remote_block_ids
=
raw_dict
[
"remote_block_ids"
],
remote_host
=
raw_dict
[
"remote_host"
],
remote_port
=
raw_dict
[
"remote_port"
],
remote_engine_id
=
raw_dict
[
"remote_engine_id"
],
)
class
NixlAgentMetadata
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
...
...
@@ -123,25 +73,18 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self
,
request_id
:
str
,
local_block_ids
:
list
[
int
],
kv_transfer_params
:
NixlKVTransferParams
,
kv_transfer_params
:
dict
[
str
,
Any
]
,
):
assert
request_id
not
in
self
.
requests
assert
kv_transfer_params
.
remote_block_ids
is
not
None
assert
kv_transfer_params
.
remote_engine_id
is
not
None
assert
kv_transfer_params
.
remote_host
is
not
None
assert
kv_transfer_params
.
remote_port
is
not
None
self
.
requests
[
request_id
]
=
ReqMeta
(
local_block_ids
=
local_block_ids
,
remote_block_ids
=
kv_transfer_params
.
remote_block_ids
,
remote_engine_id
=
kv_transfer_params
.
remote_engine_id
,
remote_host
=
kv_transfer_params
.
remote_host
,
remote_port
=
kv_transfer_params
.
remote_port
,
remote_block_ids
=
kv_transfer_params
[
"
remote_block_ids
"
]
,
remote_engine_id
=
kv_transfer_params
[
"
remote_engine_id
"
]
,
remote_host
=
kv_transfer_params
[
"
remote_host
"
]
,
remote_port
=
kv_transfer_params
[
"
remote_port
"
]
,
)
class
NixlConnector
(
KVConnectorBase_V1
):
_KVTransferParams
:
type
[
NixlKVTransferParams
]
=
NixlKVTransferParams
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
role
:
KVConnectorRole
):
assert
vllm_config
.
kv_transfer_config
is
not
None
...
...
@@ -253,52 +196,52 @@ class NixlConnectorScheduler:
asynchronously (between scheduler steps).
"""
params
=
request
.
kv_transfer_params
logger
.
debug
(
"NIXLConnector get_num_new_matched_tokens: "
"num_computed_tokens=%s, kv_transfer_params=%s"
,
num_computed_tokens
,
request
.
kv_transfer_params
)
# No KVTransfer for this request.
if
request
.
kv_transfer_params
is
None
:
return
0
,
False
assert
isinstance
(
request
.
kv_transfer_params
,
NixlKVTransferParams
)
num_computed_tokens
,
params
)
if
params
is
not
None
and
params
.
get
(
"do_remote_prefill"
):
# Remote prefill: get all prompt blocks from remote.
if
request
.
kv_transfer_params
.
do_remote_prefill
:
assert
num_computed_tokens
%
self
.
block_size
==
0
rounded_num_prompt_tokens
=
round_down
(
len
(
request
.
prompt_token_ids
),
self
.
block_size
)
count
=
max
(
rounded_num_prompt_tokens
-
num_computed_tokens
,
0
)
return
count
,
count
>
0
# No remote prefill for this request.
return
0
,
False
def
update_state_after_alloc
(
self
,
request
:
"Request"
,
blocks
:
"KVCacheBlocks"
,
num_external_tokens
:
int
):
params
=
request
.
kv_transfer_params
logger
.
debug
(
"NIXLConnector update_state_after_alloc: "
"num_external_tokens=%s, kv_transfer_params=%s"
,
num_external_tokens
,
request
.
kv_transfer_
params
)
num_external_tokens
,
params
)
if
request
.
kv_transfer_params
is
None
:
return
assert
isinstance
(
request
.
kv_transfer_params
,
NixlKVTransferParams
)
if
request
.
kv_transfer_params
.
do_remote_prefill
:
if
params
is
not
None
and
params
.
get
(
"do_remote_prefill"
):
# NOTE(rob): if prompt < block_size, no remote blocks
# since the remote only sends fully computed blocks, so
# skip recving for this request. num_external_tokens
# should be 0 if there are no remote blocks.
if
request
.
kv_transfer_params
.
remote_block_ids
:
if
params
.
get
(
"remote_block_ids"
):
if
all
(
p
in
params
for
p
in
(
"remote_engine_id"
,
"remote_host"
,
"remote_port"
)):
# Get unhashed blocks to pull from remote.
self
.
_reqs_need_recv
[
request
.
request_id
]
=
(
request
,
blocks
.
get_unhashed_block_ids
())
else
:
logger
.
warning
(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer"
,
params
)
else
:
assert
num_external_tokens
==
0
# Only trigger 1 KV transfer per request.
request
.
kv_transfer_
params
.
do_remote_prefill
=
False
params
[
"
do_remote_prefill
"
]
=
False
def
build_connector_meta
(
self
,
...
...
@@ -308,7 +251,7 @@ class NixlConnectorScheduler:
# Loop through scheduled reqs and convert to ReqMeta.
for
req_id
,
(
req
,
block_ids
)
in
self
.
_reqs_need_recv
.
items
():
assert
isinstance
(
req
.
kv_transfer_params
,
NixlKVTransferParams
)
assert
req
.
kv_transfer_params
is
not
None
meta
.
add_new_req
(
request_id
=
req_id
,
local_block_ids
=
block_ids
,
...
...
@@ -330,34 +273,30 @@ class NixlConnectorScheduler:
should be freed now or will be sent asynchronously and freed later.
"""
params
=
request
.
kv_transfer_params
logger
.
debug
(
"NIXLConnector request_finished, "
"request_status=%s, kv_transfer_params=%s"
,
request
.
status
,
request
.
kv_transfer_params
)
"NIXLConnector request_finished, request_status=%s, "
"kv_transfer_params=%s"
,
request
.
status
,
params
)
if
request
.
kv_transfer_params
is
None
:
return
False
,
None
assert
isinstance
(
request
.
kv_transfer_params
,
NixlKVTransferParams
)
if
((
not
request
.
kv_transfer_params
.
do_remote_decode
)
or
(
request
.
status
!=
RequestStatus
.
FINISHED_LENGTH_CAPPED
)):
if
(
params
is
None
or
not
params
.
get
(
"do_remote_decode"
)
or
request
.
status
!=
RequestStatus
.
FINISHED_LENGTH_CAPPED
):
return
False
,
None
# Get computed blocks.
all_full
=
request
.
num_computed_tokens
%
self
.
block_size
==
0
computed_block_ids
=
(
block_ids
if
all_full
else
block_ids
[:
-
1
]
)
computed_block_ids
=
block_ids
if
all_full
else
block_ids
[:
-
1
]
# If prompt < block_size, no xfer so free blocks immediately.
delay_free_blocks
=
len
(
computed_block_ids
)
>
0
return
delay_free_blocks
,
NixlKVTransferParams
(
return
delay_free_blocks
,
dict
(
do_remote_prefill
=
True
,
do_remote_decode
=
False
,
remote_block_ids
=
computed_block_ids
,
remote_engine_id
=
self
.
engine_id
,
remote_host
=
envs
.
VLLM_NIXL_SIDE_CHANNEL_HOST
,
remote_port
=
envs
.
VLLM_NIXL_SIDE_CHANNEL_PORT
,
)
.
__dict__
)
class
NixlConnectorWorker
:
...
...
vllm/v1/core/sched/scheduler.py
View file @
59dd311c
...
...
@@ -12,8 +12,7 @@ from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
KVConnectorRole
,
KVTransferParams
)
KVConnectorRole
)
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
...
...
@@ -931,8 +930,13 @@ class Scheduler(SchedulerInterface):
return
self
.
connector
def
_connector_finished
(
self
,
request
:
Request
)
->
tuple
[
bool
,
Optional
[
KVTransferParams
]]:
"""Invoke the KV connector request_finished() method if applicable."""
self
,
request
:
Request
)
->
tuple
[
bool
,
Optional
[
dict
[
str
,
Any
]]]:
"""
Invoke the KV connector request_finished() method if applicable.
Returns optional kv transfer parameters to be included with the
request outputs.
"""
if
self
.
connector
is
None
:
return
False
,
None
block_ids
=
self
.
kv_cache_manager
.
get_block_ids
(
request
.
request_id
)
...
...
vllm/v1/engine/core.py
View file @
59dd311c
...
...
@@ -182,13 +182,9 @@ class EngineCore:
# Start grammar compilation asynchronously
self
.
structured_output_manager
.
grammar_init
(
req
)
if
req
.
raw_kv_transfer_params
is
not
None
:
if
(
kv_connector
:
=
self
.
scheduler
.
get_kv_connector
()):
# Parse raw KV transfer params via connector.
kv_connector
.
set_kv_transfer_params
(
req
)
else
:
logger
.
warning
(
"Got KVTransferParams, but no KVConnector found. "
if
req
.
kv_transfer_params
is
not
None
and
(
not
self
.
scheduler
.
get_kv_connector
()):
logger
.
warning
(
"Got kv_transfer_params, but no KVConnector found. "
"Disabling KVTransfer for this request."
)
self
.
scheduler
.
add_request
(
req
)
...
...
vllm/v1/request.py
View file @
59dd311c
...
...
@@ -3,7 +3,6 @@
import
enum
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVTransferParams
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
is_list_of
...
...
@@ -62,14 +61,10 @@ class Request:
self
.
num_encoder_inputs
=
len
(
self
.
mm_inputs
)
self
.
has_encoder_inputs
=
self
.
num_encoder_inputs
>
0
# P/D: KV transfer parameters (raw and parsed).
raw_params
=
(
None
if
sampling_params
.
extra_args
is
None
else
sampling_params
.
extra_args
.
get
(
"kv_transfer_params"
,
None
))
self
.
raw_kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
raw_params
# Each connector parses the raw dictionary and sets this
# attr the first time that the request is processed.
self
.
kv_transfer_params
:
Optional
[
KVTransferParams
]
=
None
# P/D: Connector-specific KV transfer parameters.
kv_params
=
(
None
if
sampling_params
.
extra_args
is
None
else
sampling_params
.
extra_args
.
get
(
"kv_transfer_params"
))
self
.
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
kv_params
# Sanity check
assert
len
(
self
.
mm_inputs
)
==
len
(
self
.
mm_positions
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment