Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fd195b19
Unverified
Commit
fd195b19
authored
May 16, 2025
by
Michael Goin
Committed by
GitHub
May 16, 2025
Browse files
[V1][P/D] Local attention optimization for NIXL (#18170)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
fabe89bb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
90 additions
and
11 deletions
+90
-11
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+90
-11
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
fd195b19
...
@@ -96,7 +96,8 @@ class NixlConnector(KVConnectorBase_V1):
...
@@ -96,7 +96,8 @@ class NixlConnector(KVConnectorBase_V1):
self
.
connector_worker
:
Optional
[
NixlConnectorWorker
]
=
None
self
.
connector_worker
:
Optional
[
NixlConnectorWorker
]
=
None
elif
role
==
KVConnectorRole
.
WORKER
:
elif
role
==
KVConnectorRole
.
WORKER
:
self
.
connector_scheduler
=
None
self
.
connector_scheduler
=
None
self
.
connector_worker
=
NixlConnectorWorker
(
str
(
self
.
engine_id
))
self
.
connector_worker
=
NixlConnectorWorker
(
vllm_config
,
str
(
self
.
engine_id
))
############################################################
############################################################
# Scheduler Side Methods
# Scheduler Side Methods
...
@@ -302,7 +303,7 @@ class NixlConnectorScheduler:
...
@@ -302,7 +303,7 @@ class NixlConnectorScheduler:
class
NixlConnectorWorker
:
class
NixlConnectorWorker
:
"""Implementation of Worker side methods"""
"""Implementation of Worker side methods"""
def
__init__
(
self
,
engine_id
:
str
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_id
:
str
):
if
NixlWrapper
is
None
:
if
NixlWrapper
is
None
:
logger
.
error
(
"NIXL is not available"
)
logger
.
error
(
"NIXL is not available"
)
raise
RuntimeError
(
"NIXL is not available"
)
raise
RuntimeError
(
"NIXL is not available"
)
...
@@ -329,6 +330,7 @@ class NixlConnectorWorker:
...
@@ -329,6 +330,7 @@ class NixlConnectorWorker:
# Number of NIXL regions. Currently one region per cache
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
# (so 1 per layer for MLA, otherwise 2 per layer)
self
.
num_regions
=
0
self
.
num_regions
=
0
self
.
num_layers
=
0
# nixl_prepped_dlist_handle (int).
# nixl_prepped_dlist_handle (int).
self
.
src_xfer_side_handle
:
int
=
0
self
.
src_xfer_side_handle
:
int
=
0
...
@@ -355,6 +357,14 @@ class NixlConnectorWorker:
...
@@ -355,6 +357,14 @@ class NixlConnectorWorker:
# Background thread for establishing new connections.
# Background thread for establishing new connections.
self
.
_nixl_handshake_listener_t
:
Optional
[
threading
.
Thread
]
=
None
self
.
_nixl_handshake_listener_t
:
Optional
[
threading
.
Thread
]
=
None
self
.
vllm_config
=
vllm_config
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
# List of block window sizes for each layer for local attention
self
.
block_window_per_layer
:
list
[
Optional
[
int
]]
=
[]
@
staticmethod
@
staticmethod
def
_nixl_handshake_listener
(
metadata
:
NixlAgentMetadata
,
def
_nixl_handshake_listener
(
metadata
:
NixlAgentMetadata
,
ready_event
:
threading
.
Event
,
rank
:
int
):
ready_event
:
threading
.
Event
,
rank
:
int
):
...
@@ -465,6 +475,27 @@ class NixlConnectorWorker:
...
@@ -465,6 +475,27 @@ class NixlConnectorWorker:
kv_caches_base_addr
.
append
(
base_addr
)
kv_caches_base_addr
.
append
(
base_addr
)
self
.
kv_caches_base_addr
[
self
.
engine_id
]
=
kv_caches_base_addr
self
.
kv_caches_base_addr
[
self
.
engine_id
]
=
kv_caches_base_addr
self
.
num_regions
=
len
(
caches_data
)
self
.
num_regions
=
len
(
caches_data
)
self
.
num_layers
=
len
(
self
.
kv_caches
.
keys
())
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
if
self
.
vllm_config
.
model_config
.
hf_config
.
model_type
==
"llama4"
:
from
transformers
import
Llama4TextConfig
assert
isinstance
(
self
.
vllm_config
.
model_config
.
hf_text_config
,
Llama4TextConfig
)
llama4_config
=
self
.
vllm_config
.
model_config
.
hf_text_config
no_rope_layers
=
llama4_config
.
no_rope_layers
chunk_size
=
llama4_config
.
attention_chunk_size
chunk_block_size
=
math
.
ceil
(
chunk_size
/
self
.
block_size
)
for
layer_idx
in
range
(
self
.
num_layers
):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention
=
no_rope_layers
[
layer_idx
]
!=
0
block_window
=
chunk_block_size
if
is_local_attention
else
None
self
.
block_window_per_layer
.
append
(
block_window
)
logger
.
debug
(
"Llama 4 block window per layer mapping: %s"
,
self
.
block_window_per_layer
)
assert
len
(
self
.
block_window_per_layer
)
==
self
.
num_layers
descs
=
self
.
nixl_wrapper
.
get_reg_descs
(
caches_data
,
"VRAM"
)
descs
=
self
.
nixl_wrapper
.
get_reg_descs
(
caches_data
,
"VRAM"
)
logger
.
debug
(
"Registering descs: %s"
,
caches_data
)
logger
.
debug
(
"Registering descs: %s"
,
caches_data
)
...
@@ -699,10 +730,39 @@ class NixlConnectorWorker:
...
@@ -699,10 +730,39 @@ class NixlConnectorWorker:
remote_xfer_side_handle
=
self
.
dst_xfer_side_handles
[
dst_engine_id
]
remote_xfer_side_handle
=
self
.
dst_xfer_side_handles
[
dst_engine_id
]
# Get descs ids.
# Get descs ids.
local_block_descs_ids
:
list
[
int
]
=
[]
remote_block_descs_ids
:
list
[
int
]
=
[]
if
not
self
.
block_window_per_layer
:
# Default case: assume global attention
remote_block_descs_ids
=
self
.
_get_block_descs_ids
(
remote_block_descs_ids
=
self
.
_get_block_descs_ids
(
dst_engine_id
,
remote_block_ids
)
dst_engine_id
,
remote_block_ids
)
local_block_descs_ids
=
self
.
_get_block_descs_ids
(
local_block_descs_ids
=
self
.
_get_block_descs_ids
(
self
.
engine_id
,
local_block_ids
)
self
.
engine_id
,
local_block_ids
)
else
:
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
for
layer_idx
,
block_window
in
enumerate
(
self
.
block_window_per_layer
):
# For each layer:
if
block_window
is
None
:
# If not chunked, we just use the
# full block lists (global attention)
layer_local_block_ids
=
local_block_ids
layer_remote_block_ids
=
remote_block_ids
else
:
# If chunked, get the last block_window blocks
layer_local_block_ids
=
local_block_ids
[
-
block_window
:]
layer_remote_block_ids
=
remote_block_ids
[
-
block_window
:]
# Get descs ids for the layer.
layer_local_desc_ids
=
self
.
_get_block_descs_ids
(
self
.
engine_id
,
layer_local_block_ids
,
layer_idx
)
layer_remote_desc_ids
=
self
.
_get_block_descs_ids
(
dst_engine_id
,
layer_remote_block_ids
,
layer_idx
)
local_block_descs_ids
.
extend
(
layer_local_desc_ids
)
remote_block_descs_ids
.
extend
(
layer_remote_desc_ids
)
assert
len
(
local_block_descs_ids
)
==
len
(
remote_block_descs_ids
)
assert
len
(
local_block_descs_ids
)
==
len
(
remote_block_descs_ids
)
# Prepare transfer with Nixl.
# Prepare transfer with Nixl.
...
@@ -721,12 +781,31 @@ class NixlConnectorWorker:
...
@@ -721,12 +781,31 @@ class NixlConnectorWorker:
# Use handle to check completion in future step().
# Use handle to check completion in future step().
self
.
_recving_transfers
[
request_id
].
append
(
handle
)
self
.
_recving_transfers
[
request_id
].
append
(
handle
)
def
_get_block_descs_ids
(
self
,
engine_id
:
str
,
def
_get_block_descs_ids
(
self
,
block_ids
:
list
[
int
])
->
list
[
int
]:
engine_id
:
str
,
"""Get the descs ids for a set of block ids."""
block_ids
:
list
[
int
],
layer_idx
:
Optional
[
int
]
=
None
)
->
list
[
int
]:
"""
Get the descs ids for a set of block ids.
If layer_idx is provided, we use the region_ids for the given layer.
Otherwise, we use all regions.
"""
# range(1) for MLA, range(2) otherwise.
if
layer_idx
is
None
:
region_ids
=
range
(
self
.
num_regions
)
region_ids
=
range
(
self
.
num_regions
)
else
:
assert
layer_idx
<
self
.
num_layers
if
self
.
num_layers
<
self
.
num_regions
:
# If we have more regions than layers, we assume that
# the regions are organized as [K0, V0, K1, V1, ...]
# and we select K_i and V_i
assert
2
*
self
.
num_layers
==
self
.
num_regions
region_ids
=
range
(
2
*
layer_idx
,
2
*
layer_idx
+
2
)
else
:
# Otherwise, we assume we have MLA and select i-th layer
assert
self
.
num_layers
==
self
.
num_regions
region_ids
=
range
(
layer_idx
,
layer_idx
+
1
)
num_blocks
=
self
.
dst_num_blocks
[
engine_id
]
num_blocks
=
self
.
dst_num_blocks
[
engine_id
]
# Compute the desc ids for each block.
# Compute the desc ids for each block.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment