Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d0132f02
Unverified
Commit
d0132f02
authored
Jun 23, 2025
by
lkchen
Committed by
GitHub
Jun 23, 2025
Browse files
[Misc] Add type alias `ReqId` and `EngineId` for better readability (#19880)
Signed-off-by:
Linkun Chen
<
github@lkchen.net
>
parent
61f4fc5d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
17 deletions
+20
-17
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+20
-17
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
d0132f02
...
@@ -36,6 +36,8 @@ if TYPE_CHECKING:
...
@@ -36,6 +36,8 @@ if TYPE_CHECKING:
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
Transfer
=
tuple
[
int
,
float
]
# (xfer_handle, start_time)
Transfer
=
tuple
[
int
,
float
]
# (xfer_handle, start_time)
EngineId
=
str
ReqId
=
str
GET_META_MSG
=
b
"get_meta_msg"
GET_META_MSG
=
b
"get_meta_msg"
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -75,7 +77,7 @@ class ReqMeta:
...
@@ -75,7 +77,7 @@ class ReqMeta:
class
NixlConnectorMetadata
(
KVConnectorMetadata
):
class
NixlConnectorMetadata
(
KVConnectorMetadata
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
requests
:
dict
[
str
,
ReqMeta
]
=
{}
self
.
requests
:
dict
[
ReqId
,
ReqMeta
]
=
{}
def
add_new_req
(
def
add_new_req
(
self
,
self
,
...
@@ -96,16 +98,17 @@ class NixlConnector(KVConnectorBase_V1):
...
@@ -96,16 +98,17 @@ class NixlConnector(KVConnectorBase_V1):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
role
:
KVConnectorRole
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
role
:
KVConnectorRole
):
assert
vllm_config
.
kv_transfer_config
is
not
None
assert
vllm_config
.
kv_transfer_config
is
not
None
self
.
engine_id
=
vllm_config
.
kv_transfer_config
.
engine_id
assert
vllm_config
.
kv_transfer_config
.
engine_id
is
not
None
self
.
engine_id
:
EngineId
=
vllm_config
.
kv_transfer_config
.
engine_id
if
role
==
KVConnectorRole
.
SCHEDULER
:
if
role
==
KVConnectorRole
.
SCHEDULER
:
self
.
connector_scheduler
:
Optional
[
NixlConnectorScheduler
]
=
\
self
.
connector_scheduler
:
Optional
[
NixlConnectorScheduler
]
=
\
NixlConnectorScheduler
(
vllm_config
,
str
(
self
.
engine_id
)
)
NixlConnectorScheduler
(
vllm_config
,
self
.
engine_id
)
self
.
connector_worker
:
Optional
[
NixlConnectorWorker
]
=
None
self
.
connector_worker
:
Optional
[
NixlConnectorWorker
]
=
None
elif
role
==
KVConnectorRole
.
WORKER
:
elif
role
==
KVConnectorRole
.
WORKER
:
self
.
connector_scheduler
=
None
self
.
connector_scheduler
=
None
self
.
connector_worker
=
NixlConnectorWorker
(
self
.
connector_worker
=
NixlConnectorWorker
(
vllm_config
,
str
(
self
.
engine_id
)
)
vllm_config
,
self
.
engine_id
)
############################################################
############################################################
# Scheduler Side Methods
# Scheduler Side Methods
...
@@ -179,7 +182,7 @@ class NixlConnectorScheduler:
...
@@ -179,7 +182,7 @@ class NixlConnectorScheduler:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
):
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
engine_id
=
engine_id
self
.
engine_id
:
EngineId
=
engine_id
self
.
side_channel_host
=
envs
.
VLLM_NIXL_SIDE_CHANNEL_HOST
self
.
side_channel_host
=
envs
.
VLLM_NIXL_SIDE_CHANNEL_HOST
self
.
side_channel_port
=
(
self
.
side_channel_port
=
(
envs
.
VLLM_NIXL_SIDE_CHANNEL_PORT
+
envs
.
VLLM_NIXL_SIDE_CHANNEL_PORT
+
...
@@ -190,7 +193,7 @@ class NixlConnectorScheduler:
...
@@ -190,7 +193,7 @@ class NixlConnectorScheduler:
# Requests that need to start recv.
# Requests that need to start recv.
# New requests are added by update_state_after_alloc in
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
# the scheduler. Used to make metadata passed to Worker.
self
.
_reqs_need_recv
:
dict
[
str
,
tuple
[
Request
,
list
[
int
]]]
=
{}
self
.
_reqs_need_recv
:
dict
[
ReqId
,
tuple
[
Request
,
list
[
int
]]]
=
{}
def
get_num_new_matched_tokens
(
def
get_num_new_matched_tokens
(
self
,
request
:
"Request"
,
self
,
request
:
"Request"
,
...
@@ -332,19 +335,19 @@ class NixlConnectorWorker:
...
@@ -332,19 +335,19 @@ class NixlConnectorWorker:
# Agent.
# Agent.
self
.
nixl_wrapper
=
NixlWrapper
(
str
(
uuid
.
uuid4
()),
None
)
self
.
nixl_wrapper
=
NixlWrapper
(
str
(
uuid
.
uuid4
()),
None
)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1, ...}.
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1, ...}.
self
.
_remote_agents
:
dict
[
str
,
dict
[
int
,
str
]]
=
defaultdict
(
dict
)
self
.
_remote_agents
:
dict
[
EngineId
,
dict
[
int
,
str
]]
=
defaultdict
(
dict
)
# NIXL handshake port.
# NIXL handshake port.
# NOTE(rob): Within a DP group, each DP rank gets its own
# NOTE(rob): Within a DP group, each DP rank gets its own
# base port (which is sent in the KVTransferParams).
# base port (which is sent in the KVTransferParams).
# Each TP rank listens/queries on the base_port + tp_rank.
# Each TP rank listens/queries on the base_port + tp_rank.
self
.
side_channel_port
=
(
self
.
side_channel_port
:
int
=
(
envs
.
VLLM_NIXL_SIDE_CHANNEL_PORT
+
envs
.
VLLM_NIXL_SIDE_CHANNEL_PORT
+
vllm_config
.
parallel_config
.
data_parallel_rank_local
*
vllm_config
.
parallel_config
.
data_parallel_rank_local
*
vllm_config
.
parallel_config
.
tensor_parallel_size
)
vllm_config
.
parallel_config
.
tensor_parallel_size
)
# Metadata.
# Metadata.
self
.
engine_id
=
engine_id
self
.
engine_id
:
EngineId
=
engine_id
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
world_size
=
get_tensor_model_parallel_world_size
()
self
.
world_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_group
=
get_tp_group
()
self
.
tp_group
=
get_tp_group
()
...
@@ -354,7 +357,7 @@ class NixlConnectorWorker:
...
@@ -354,7 +357,7 @@ class NixlConnectorWorker:
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
# rank will still only pull from a single remote TP worker.
self
.
kv_caches_base_addr
:
dict
[
str
,
list
[
int
]]
=
{}
self
.
kv_caches_base_addr
:
dict
[
EngineId
,
list
[
int
]]
=
{}
# Number of NIXL regions. Currently one region per cache
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
# (so 1 per layer for MLA, otherwise 2 per layer)
...
@@ -364,23 +367,23 @@ class NixlConnectorWorker:
...
@@ -364,23 +367,23 @@ class NixlConnectorWorker:
# nixl_prepped_dlist_handle.
# nixl_prepped_dlist_handle.
self
.
src_xfer_side_handle
:
int
=
0
self
.
src_xfer_side_handle
:
int
=
0
# Map of engine_id -> nixl_prepped_dlist_handle (int).
# Map of engine_id -> nixl_prepped_dlist_handle (int).
self
.
dst_xfer_side_handles
:
dict
[
str
,
int
]
=
{}
self
.
dst_xfer_side_handles
:
dict
[
EngineId
,
int
]
=
{}
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks.
# have the same number of blocks.
self
.
dst_num_blocks
:
dict
[
str
,
int
]
=
{}
self
.
dst_num_blocks
:
dict
[
EngineId
,
int
]
=
{}
self
.
_registered_descs
:
list
[
Any
]
=
[]
self
.
_registered_descs
:
list
[
Any
]
=
[]
# In progress transfers.
# In progress transfers.
# [req_id -> list[handle]]
# [req_id -> list[handle]]
self
.
_recving_transfers
=
defaultdict
[
str
,
list
[
Transfer
]](
list
)
self
.
_recving_transfers
=
defaultdict
[
ReqId
,
list
[
Transfer
]](
list
)
# Complete transfer tracker. Used by rank 0 to track finished
# Complete transfer tracker. Used by rank 0 to track finished
# transactions on ranks 1 to N-1.
# transactions on ranks 1 to N-1.
# [req_id -> count]
# [req_id -> count]
self
.
_done_recving_count
:
defaultdict
[
str
,
self
.
_done_recving_count
:
defaultdict
[
ReqId
,
int
]
=
defaultdict
(
lambda
:
0
)
int
]
=
defaultdict
(
lambda
:
0
)
self
.
_done_sending_count
:
defaultdict
[
str
,
self
.
_done_sending_count
:
defaultdict
[
ReqId
,
int
]
=
defaultdict
(
lambda
:
0
)
int
]
=
defaultdict
(
lambda
:
0
)
# Background thread for establishing new connections.
# Background thread for establishing new connections.
...
@@ -408,10 +411,10 @@ class NixlConnectorWorker:
...
@@ -408,10 +411,10 @@ class NixlConnectorWorker:
self
.
_use_flashinfer
=
attn_backend
==
_Backend
.
FLASHINFER_VLLM_V1
self
.
_use_flashinfer
=
attn_backend
==
_Backend
.
FLASHINFER_VLLM_V1
logger
.
debug
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
debug
(
"Detected attention backend %s"
,
self
.
backend_name
)
self
.
_tp_size
:
dict
[
str
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
self
.
_tp_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
# With heterogeneous TP, P must wait for all assigned D TP workers to
# With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks.
# finish reading before safely freeing the blocks.
self
.
consumer_notification_counts_by_req
=
defaultdict
[
str
,
int
](
int
)
self
.
consumer_notification_counts_by_req
=
defaultdict
[
ReqId
,
int
](
int
)
@
staticmethod
@
staticmethod
def
_nixl_handshake_listener
(
metadata
:
NixlAgentMetadata
,
def
_nixl_handshake_listener
(
metadata
:
NixlAgentMetadata
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment