Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b2fac671
Unverified
Commit
b2fac671
authored
Jun 05, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Jun 04, 2025
Browse files
[P/D] Heterogeneous TP (#18833)
Signed-off-by:
nicklucche
<
nlucches@redhat.com
>
parent
23027e2d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
287 additions
and
100 deletions
+287
-100
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+8
-3
tests/v1/kv_connector/nixl_integration/test_accuracy.py
tests/v1/kv_connector/nixl_integration/test_accuracy.py
+1
-0
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+17
-1
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+243
-95
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+16
-0
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+2
-1
No files found.
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
View file @
b2fac671
...
...
@@ -8,7 +8,9 @@ MODELS=(
# Number of prefill and decode instances to create
NUM_PREFILL_INSTANCES
=
${
NUM_PREFILL_INSTANCES
:-
1
}
# Default to 1
NUM_DECODE_INSTANCES
=
${
NUM_DECODE_INSTANCES
:-
2
}
# Default to 2
NUM_DECODE_INSTANCES
=
${
NUM_DECODE_INSTANCES
:-
1
}
# Default to 1
PREFILLER_TP_SIZE
=
${
PREFILLER_TP_SIZE
:-
1
}
DECODER_TP_SIZE
=
${
DECODER_TP_SIZE
:-
1
}
# Find the git repository root directory
GIT_ROOT
=
$(
git rev-parse
--show-toplevel
)
...
...
@@ -74,9 +76,10 @@ run_tests_for_model() {
for
i
in
$(
seq
0
$((
NUM_PREFILL_INSTANCES-1
))
)
;
do
# Calculate GPU ID - we'll distribute across available GPUs
GPU_ID
=
$((
i
%
$(
get_num_gpus
)
))
# Calculate port number (base port + instance number)
PORT
=
$((
8100
+
i
))
# Calculate side channel port
# Calculate side channel port
. Avoid clash with TP workers.
SIDE_CHANNEL_PORT
=
$((
5559
+
i
))
echo
"Starting prefill instance
$i
on GPU
$GPU_ID
, port
$PORT
"
...
...
@@ -87,6 +90,7 @@ run_tests_for_model() {
--enforce-eager
\
--disable-log-requests
\
--gpu-memory-utilization 0.2
\
--tensor-parallel-size
$PREFILLER_TP_SIZE
\
--kv-transfer-config '{
\"
kv_connector
\"
:
\"
NixlConnector
\"
,
\"
kv_role
\"
:
\"
kv_both
\"
}'"
if
[
-n
"
$model_args
"
]
;
then
...
...
@@ -109,7 +113,7 @@ run_tests_for_model() {
# Calculate port number (base port + instance number)
PORT
=
$((
8200
+
i
))
# Calculate side channel port
SIDE_CHANNEL_PORT
=
$((
5659
+
i
))
SIDE_CHANNEL_PORT
=
$((
5659
+
i
*
$DECODER_TP_SIZE
))
echo
"Starting decode instance
$i
on GPU
$GPU_ID
, port
$PORT
"
...
...
@@ -119,6 +123,7 @@ run_tests_for_model() {
--enforce-eager
\
--disable-log-requests
\
--gpu-memory-utilization 0.2
\
--tensor-parallel-size
$DECODER_TP_SIZE
\
--kv-transfer-config '{
\"
kv_connector
\"
:
\"
NixlConnector
\"
,
\"
kv_role
\"
:
\"
kv_both
\"
}'"
if
[
-n
"
$model_args
"
]
;
then
...
...
tests/v1/kv_connector/nixl_integration/test_accuracy.py
View file @
b2fac671
...
...
@@ -14,6 +14,7 @@ RTOL = 0.03
# Model-specific expected values
EXPECTED_VALUES
=
{
"Qwen/Qwen3-0.6B"
:
0.41
,
"deepseek-ai/deepseek-vl2-small"
:
0.59
}
SIMPLE_PROMPT
=
"The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means"
,
# noqa: E501
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
b2fac671
...
...
@@ -3,11 +3,12 @@
"""
KV cache helper for store.
"""
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
...
...
@@ -90,3 +91,18 @@ class model_aware_kv_ops_helper:
layer
.
self_attn
.
attn
.
_k_scale
,
layer
.
self_attn
.
attn
.
_v_scale
,
)
def get_kv_connector_cache_layout():
    """Return the KV cache memory layout name to use: "HND" or "NHD".

    "HND" is returned only when all of the following hold:
      * the current vLLM config carries a model config,
      * the model does not use MLA, and
      * the configured KV connector is "NixlConnector".

    In every other case the default "NHD" layout is returned.

    Returns:
        The string "HND" or "NHD".
    """
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
    if vllm_config.model_config is None:
        logger.warning("Unable to detect current VLLM config. "
                       "Defaulting to NHD kv cache layout.")
        return "NHD"

    use_mla = vllm_config.model_config.use_mla
    # NOTE(review): kv_transfer_config is assumed to be optional (None when
    # no KV connector is configured) -- guard the attribute access so that
    # this case falls back to NHD instead of raising AttributeError.
    if (not use_mla and kv_config is not None
            and kv_config.kv_connector == "NixlConnector"):
        logger.info("NixlConnector detected. Setting KV cache "
                    "layout to HND for better xfer performance.")
        return "HND"
    return "NHD"
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
b2fac671
...
...
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.request
import
Request
Transfer
=
tuple
[
int
,
float
]
# (xfer_handle, start_time)
GET_META_MSG
=
b
"get_meta_msg"
logger
=
init_logger
(
__name__
)
...
...
@@ -54,6 +55,8 @@ class NixlAgentMetadata(
agent_metadata
:
bytes
kv_caches_base_addr
:
list
[
int
]
num_blocks
:
int
tp_size
:
int
block_len
:
int
@
dataclass
...
...
@@ -331,10 +334,14 @@ class NixlConnectorWorker:
logger
.
info
(
"Initializing NIXL wrapper"
)
logger
.
info
(
"Initializing NIXL worker %s"
,
engine_id
)
# Config.
self
.
vllm_config
=
vllm_config
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
# Agent.
self
.
nixl_wrapper
=
NixlWrapper
(
str
(
uuid
.
uuid4
()),
None
)
# Map of engine_id -> agent_name.
self
.
_remote_agents
:
dict
[
str
,
str
]
=
{}
# Map of engine_id ->
{rank0:
agent_name
0, rank1: agent_name1..}
.
self
.
_remote_agents
:
dict
[
str
,
dict
[
int
,
str
]
]
=
defaultdict
(
dict
)
# NIXL handshake port.
# NOTE(rob): Within a DP group, each DP rank gets its own
...
...
@@ -354,7 +361,8 @@ class NixlConnectorWorker:
# KV Caches and nixl tracking data.
self
.
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
# Map of engine_id -> kv_caches_base_addr
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self
.
kv_caches_base_addr
:
dict
[
str
,
list
[
int
]]
=
{}
# Number of NIXL regions. Currently one region per cache
...
...
@@ -362,19 +370,19 @@ class NixlConnectorWorker:
self
.
num_regions
=
0
self
.
num_layers
=
0
# nixl_prepped_dlist_handle
(int)
.
# nixl_prepped_dlist_handle.
self
.
src_xfer_side_handle
:
int
=
0
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
self
.
dst_xfer_side_handles
:
dict
[
str
,
int
]
=
{}
# Map of engine_id -> num_blocks.
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks.
self
.
dst_num_blocks
:
dict
[
str
,
int
]
=
{}
self
.
_registered_descs
:
list
[
Any
]
=
[]
# In progress transfers.
# [req_id -> list[handle]]
self
.
_recving_transfers
:
defaultdict
[
str
,
list
[
Any
]]
=
defaultdict
(
list
[
Any
])
self
.
_recving_transfers
=
defaultdict
[
str
,
list
[
Transfer
]](
list
)
# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
...
...
@@ -395,6 +403,11 @@ class NixlConnectorWorker:
# List of block window sizes for each layer for local attention
self
.
block_window_per_layer
:
list
[
Optional
[
int
]]
=
[]
self
.
_tp_size
:
dict
[
str
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
# With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks.
self
.
consumer_notification_counts_by_req
=
defaultdict
[
str
,
int
](
int
)
@
staticmethod
def
_nixl_handshake_listener
(
metadata
:
NixlAgentMetadata
,
ready_event
:
threading
.
Event
,
base_port
:
int
,
...
...
@@ -426,27 +439,44 @@ class NixlConnectorWorker:
"""Do a NIXL handshake with a remote instance."""
start_time
=
time
.
perf_counter
()
# NOTE(rob): we need each tp_rank to have a unique port.
# This is a hack to keep us moving. We will switch when
# we switch to HTTP-based NIXL metadata exchange.
path
=
make_zmq_path
(
"tcp"
,
host
,
port
+
self
.
tp_rank
)
logger
.
debug
(
"Querying metadata on path: %s"
,
path
)
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
# Send query for the request.
sock
.
send
(
GET_META_MSG
)
metadata_bytes
=
sock
.
recv
()
decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
metadata
=
decoder
.
decode
(
metadata_bytes
)
got_metadata_time
=
time
.
perf_counter
()
# Register Remote agent.
self
.
add_remote_agent
(
metadata
)
setup_agent_time
=
time
.
perf_counter
()
# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
)
logger
.
debug
(
"NIXL handshake: add agent took: %s"
,
setup_agent_time
-
got_metadata_time
)
def
handshake
(
path
:
str
,
rank
:
int
)
->
NixlAgentMetadata
:
# Send query for the request.
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
sock
.
send
(
GET_META_MSG
)
metadata_bytes
=
sock
.
recv
()
decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
metadata
=
decoder
.
decode
(
metadata_bytes
)
got_metadata_time
=
time
.
perf_counter
()
# Register Remote agent.
self
.
add_remote_agent
(
metadata
,
rank
)
setup_agent_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
)
logger
.
debug
(
"NIXL handshake: add agent took: %s"
,
setup_agent_time
-
got_metadata_time
)
return
metadata
# Handshake with remote agent-rank0 first to get the tp_size of remote
path
=
make_zmq_path
(
"tcp"
,
host
,
port
)
logger
.
debug
(
"Querying master rank metadata on path: %s"
,
path
)
metadata
=
handshake
(
path
,
0
)
# Handshake only with the other TP remote the current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio
=
self
.
_tp_size
[
self
.
engine_id
]
//
metadata
.
tp_size
p_remote_rank
=
self
.
tp_rank
//
tp_ratio
if
p_remote_rank
>
0
:
path
=
make_zmq_path
(
"tcp"
,
host
,
port
+
p_remote_rank
)
logger
.
debug
(
"Querying metadata on path: %s at remote rank %s"
,
path
,
p_remote_rank
)
_
=
handshake
(
path
,
p_remote_rank
)
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
"""Register the KV Cache data in nixl."""
...
...
@@ -455,24 +485,34 @@ class NixlConnectorWorker:
kv_elem_size
=
first_kv_cache
.
element_size
()
# TODO(tms): Find a more robust way to detect and handle MLA
use_mla
=
len
(
first_kv_cache
.
shape
)
==
3
if
use_mla
:
self
.
use_mla
=
len
(
first_kv_cache
.
shape
)
==
3
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# only affects the strides. For MLA, instead, we require no
# such thing and resort to the standard layout.
if
self
.
use_mla
:
# MLA case.
self
.
num_blocks
=
first_kv_cache
.
shape
[
0
]
block_rank
=
2
# [block_size, latent_dim]
block_shape
=
first_kv_cache
.
shape
[
-
block_rank
:]
block_size
,
kv_latent_dim
=
block_shape
self
.
slot_size_bytes
=
kv_elem_size
*
kv_latent_dim
else
:
# [2 (k and v), num_blocks,
...
]
# [2 (k and v), num_blocks,
block_size, kv_heads, head_dim
]
self
.
num_blocks
=
first_kv_cache
.
shape
[
1
]
block_rank
=
3
# [block_size, kv_heads, head_dim]
block_shape
=
first_kv_cache
.
shape
[
-
block_rank
:]
block_size
,
n_kv_heads
,
head_dim
=
block_shape
# head size in bytes.
self
.
slot_size_bytes
=
kv_elem_size
*
n_kv_heads
*
head_dim
assert
block_size
==
self
.
block_size
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self
.
block_len
=
kv_elem_size
*
math
.
prod
(
block_shape
)
logger
.
debug
(
"Registering KV_Caches. use_mla: %s, shape %s"
,
use_mla
,
first_kv_cache
.
shape
)
logger
.
debug
(
"Registering KV_Caches. use_mla: %s, shape %s"
,
self
.
use_mla
,
first_kv_cache
.
shape
)
logger
.
debug
(
"num_blocks: %s, block_shape: %s"
,
self
.
num_blocks
,
block_shape
)
logger
.
debug
(
"Per layer kv cache size: %s"
,
first_kv_cache
.
shape
)
...
...
@@ -489,7 +529,7 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
for
cache_or_caches
in
kv_caches
.
values
():
# Normalize to always be a list of caches
cache_list
=
[
cache_or_caches
]
if
use_mla
else
cache_or_caches
cache_list
=
[
cache_or_caches
]
if
self
.
use_mla
else
cache_or_caches
for
cache
in
cache_list
:
base_addr
=
cache
.
data_ptr
()
region_len
=
self
.
num_blocks
*
self
.
block_len
...
...
@@ -524,16 +564,37 @@ class NixlConnectorWorker:
logger
.
debug
(
"Registering descs: %s"
,
caches_data
)
self
.
nixl_wrapper
.
register_memory
(
descs
)
logger
.
debug
(
"Done registering descs"
)
self
.
_registered_descs
.
append
(
descs
)
# Register local/src descr for NIXL xfer.
blocks_data
=
[]
for
base_addr
in
self
.
kv_caches_base_addr
[
self
.
engine_id
]:
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
# select agent_meta.num_blocks instead of self.num_blocks for
# local descr, and that makes handling regular flow less clean.
for
block_id
in
range
(
self
.
num_blocks
):
block_offset
=
block_id
*
self
.
block_len
addr
=
base_addr
+
block_offset
# (addr, len, device id)
blocks_data
.
append
((
addr
,
self
.
block_len
,
self
.
tp_rank
))
logger
.
debug
(
"Created %s blocks for src engine %s and rank %s"
,
len
(
blocks_data
),
self
.
engine_id
,
self
.
tp_rank
)
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
"VRAM"
)
# NIXL_INIT_AGENT to be used for preparations of local descs.
self
.
src_xfer_side_handle
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
# After KV Caches registered, listen for new connections.
metadata
=
NixlAgentMetadata
(
engine_id
=
self
.
engine_id
,
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
num_blocks
=
self
.
num_blocks
,
)
tp_size
=
self
.
world_size
,
block_len
=
self
.
block_len
)
ready_event
=
threading
.
Event
()
self
.
_nixl_handshake_listener_t
=
threading
.
Thread
(
target
=
self
.
_nixl_handshake_listener
,
...
...
@@ -543,50 +604,123 @@ class NixlConnectorWorker:
self
.
_nixl_handshake_listener_t
.
start
()
ready_event
.
wait
()
def
add_remote_agent
(
self
,
nixl_agent_meta
:
NixlAgentMetadata
):
def
add_remote_agent
(
self
,
nixl_agent_meta
:
NixlAgentMetadata
,
remote_tp_rank
:
int
=
0
):
"""
Add the remote NIXL agent and prepare the descriptors for reading cache
blocks from remote.
In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i.
The latter, assuming D.world_size > P.world_size, requires that two or
more local TP workers share the xfer from a single remote TP worker.
Here's an example:
rank_offset p_remote_tp_rank
(kv split no)
--------------------------------
0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ]
/
1 0 Worker1 ---- 2nd half of KV -----/
0 1 Worker2 ---- 1st half of KV ----> Worker1 [ KV Cache ]
/
1 1 Worker3 ---- 2nd half of KV -----/
Decoder TP workers Prefill TP workers
(world_size=4) (world_size=2)
tp_ratio = 4 // 2 = 2
Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim]
then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format.
Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio
first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split
along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0.
Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1.
Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0
so that the whole cache is shared by "tp_ratio" D TP workers.
"""
# noqa: E501
engine_id
=
nixl_agent_meta
.
engine_id
assert
engine_id
!=
self
.
engine_id
,
"Conflict engine id found!"
if
engine_id
in
self
.
_remote_agents
:
# TODO re-evaluate refreshing for scaling/recovery
if
remote_tp_rank
in
self
.
_remote_agents
.
get
(
engine_id
,
())
:
return
self
.
_remote_agents
[
engine_id
]
=
self
.
nixl_wrapper
.
add_remote_agent
(
nixl_agent_meta
.
agent_metadata
)
self
.
kv_caches_base_addr
[
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
if
engine_id
in
self
.
_tp_size
:
assert
self
.
_tp_size
[
engine_id
]
==
nixl_agent_meta
.
tp_size
else
:
self
.
_tp_size
[
engine_id
]
=
nixl_agent_meta
.
tp_size
self
.
_remote_agents
[
engine_id
][
remote_tp_rank
]
=
self
.
nixl_wrapper
.
add_remote_agent
(
nixl_agent_meta
.
agent_metadata
)
# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
assert
self
.
_tp_size
[
self
.
engine_id
]
%
self
.
_tp_size
[
engine_id
]
==
0
,
\
"Local TP size must be divisible by remote TP size."
tp_ratio
=
self
.
_tp_size
[
self
.
engine_id
]
//
self
.
_tp_size
[
engine_id
]
assert
tp_ratio
>
0
,
"Decode TP cannot be smaller than"
" prefill TP"
if
self
.
use_mla
:
# With MLA the only difference is in the number of blocks.
remote_block_size
=
nixl_agent_meta
.
block_len
/
(
self
.
slot_size_bytes
)
assert
self
.
block_len
==
nixl_agent_meta
.
block_len
else
:
remote_block_size
=
nixl_agent_meta
.
block_len
/
(
self
.
slot_size_bytes
*
tp_ratio
)
assert
nixl_agent_meta
.
block_len
==
self
.
block_len
*
tp_ratio
,
\
"Remote P worker KV layer cache must be of shape [2, N,
\
local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
# Create src descs and xfer side handles.
blocks_data
=
[]
for
base_addr
in
self
.
kv_caches_base_addr
[
self
.
engine_id
]:
for
block_id
in
range
(
self
.
num_blocks
):
block_offset
=
block_id
*
self
.
block_len
# (addr, len, device id)
blocks_data
.
append
(
(
base_addr
+
block_offset
,
self
.
block_len
,
self
.
tp_rank
))
logger
.
debug
(
"Created %s blocks for src engine %s and tp_rank %s"
,
len
(
blocks_data
),
self
.
engine_id
,
self
.
tp_rank
)
assert
self
.
block_size
==
remote_block_size
,
"Remote P worker with
\
different block size is not supported"
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
"VRAM"
)
self
.
src_xfer_side_handle
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
assert
self
.
num_blocks
>=
nixl_agent_meta
.
num_blocks
# Create dst descs and xfer side handles. TP workers have same #blocks.
if
engine_id
in
self
.
dst_num_blocks
:
assert
self
.
dst_num_blocks
[
engine_id
]
==
nixl_agent_meta
.
num_blocks
else
:
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
# Create dst descs and xfer side handles.
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
blocks_data
=
[]
for
base_addr
in
self
.
kv_caches_base_addr
[
engine_id
]:
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
block_offset
=
block_id
*
self
.
block_len
# (addr, len, device id)
blocks_data
.
append
(
(
base_addr
+
block_offset
,
self
.
block_len
,
self
.
tp_rank
))
logger
.
debug
(
"Created %s blocks for dst engine %s and tp_rank %s"
,
len
(
blocks_data
),
engine_id
,
self
.
tp_rank
)
# With homogeneous TP, D pulls the whole kv cache from corresponding
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
p_remote_tp_rank
=
self
.
tp_rank
//
tp_ratio
# Only register the remote's descriptors if current rank pulls from it.
if
p_remote_tp_rank
==
remote_tp_rank
:
self
.
kv_caches_base_addr
[
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
rank_offset
=
self
.
tp_rank
%
tp_ratio
*
self
.
block_len
\
if
not
self
.
use_mla
else
0
# Register all remote blocks, but only the corresponding kv heads.
for
base_addr
in
nixl_agent_meta
.
kv_caches_base_addr
:
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
block_offset
=
block_id
*
nixl_agent_meta
.
block_len
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr
=
base_addr
+
block_offset
+
rank_offset
# (addr, len, device id)
blocks_data
.
append
((
addr
,
self
.
block_len
,
remote_tp_rank
))
logger
.
debug
(
"Created %s blocks for dst engine %s with remote rank %s and "
\
"local rank %s"
,
len
(
blocks_data
),
engine_id
,
remote_tp_rank
,
self
.
tp_rank
)
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
"VRAM"
)
self
.
dst_xfer_side_handles
[
engine_id
]
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
self
.
_remote_agents
[
engine_id
],
descs
)
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
"VRAM"
)
self
.
dst_xfer_side_handles
[
engine_id
]
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
self
.
_remote_agents
[
engine_id
]
[
remote_tp_rank
]
,
descs
)
def
get_finished
(
self
)
->
tuple
[
set
[
str
],
set
[
str
]]:
"""
...
...
@@ -654,16 +788,25 @@ class NixlConnectorWorker:
return
done_sending
,
done_recving
def
_get_new_notifs
(
self
)
->
set
[
str
]:
"""Get req_ids which got a remote xfer message."""
"""
Get req_ids which got a remote xfer message. When multiple consumers
are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling.
"""
notified_req_ids
:
set
[
str
]
=
set
()
for
req_ids
in
self
.
nixl_wrapper
.
get_new_notifs
().
values
():
for
req_id
in
req_ids
:
assert
req_id
not
in
notified_req_ids
notified_req_ids
.
add
(
req_id
.
decode
(
"utf-8"
))
for
notifs
in
self
.
nixl_wrapper
.
get_new_notifs
().
values
():
for
notif
in
notifs
:
req_id
,
tp_ratio
=
notif
.
decode
(
"utf-8"
).
rsplit
(
":"
,
1
)
self
.
consumer_notification_counts_by_req
[
req_id
]
+=
1
# Wait for all consumers (D) to be done reading before freeing.
if
self
.
consumer_notification_counts_by_req
[
req_id
]
==
int
(
tp_ratio
):
notified_req_ids
.
add
(
req_id
)
del
self
.
consumer_notification_counts_by_req
[
req_id
]
return
notified_req_ids
def
_pop_done_transfers
(
self
,
transfers
:
dict
[
str
,
list
[
int
]])
->
set
[
str
]:
def
_pop_done_transfers
(
self
,
transfers
:
dict
[
str
,
list
[
tuple
[
int
,
float
]]])
->
set
[
str
]:
"""
Pop completed xfers by checking for DONE state.
Args:
...
...
@@ -673,23 +816,17 @@ class NixlConnectorWorker:
"""
done_req_ids
:
set
[
str
]
=
set
()
for
req_id
,
handles
in
list
(
transfers
.
items
()):
running_reqs
=
[]
for
handle
in
handles
:
for
handle
,
xfer_stime
in
handles
:
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
if
xfer_state
==
"DONE"
:
# TODO ptarasiewicz: why abort is throwing errors?
# self.nixl_wrapper.release_xfer_handle(handle)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
done_req_ids
.
add
(
req_id
)
del
transfers
[
req_id
]
elif
xfer_state
==
"PROC"
:
continue
if
xfer_state
==
"PROC"
:
running_reqs
.
append
(
handle
)
else
:
raise
RuntimeError
(
"Transfer failed with state %s"
,
xfer_state
)
if
len
(
running_reqs
)
==
0
:
done_req_ids
.
add
(
req_id
)
del
transfers
[
req_id
]
else
:
transfers
[
req_id
]
=
running_reqs
return
done_req_ids
def
start_load_kv
(
self
,
metadata
:
NixlConnectorMetadata
):
...
...
@@ -735,13 +872,19 @@ class NixlConnectorWorker:
# saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready.
# Number of D TP workers that will read from dst P. Propagate tp_ratio
# on notification so that dst worker can wait before freeing blocks.
tp_ratio
=
self
.
_tp_size
[
self
.
engine_id
]
//
self
.
_tp_size
[
dst_engine_id
]
notif_id
=
f
"
{
request_id
}
:
{
tp_ratio
}
"
.
encode
()
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
num_local_blocks
=
len
(
local_block_ids
)
if
num_local_blocks
==
0
:
agent_name
=
self
.
_remote_agents
[
dst_engine_id
]
self
.
nixl_wrapper
.
send_notif
(
agent_name
,
notif_msg
=
request_id
.
encode
(
"utf-8"
)
)
remote_rank
=
self
.
tp_rank
//
tp_ratio
agent_name
=
self
.
_remote_agents
[
dst_engine_id
][
remote_rank
]
self
.
nixl_wrapper
.
send_notif
(
agent_name
,
notif_msg
=
notif_id
)
return
# Partial prefix cache hit: just read uncomputed blocks.
...
...
@@ -754,6 +897,10 @@ class NixlConnectorWorker:
local_xfer_side_handle
=
self
.
src_xfer_side_handle
remote_xfer_side_handle
=
self
.
dst_xfer_side_handles
[
dst_engine_id
]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
# Get descs ids.
local_block_descs_ids
:
list
[
int
]
=
[]
remote_block_descs_ids
:
list
[
int
]
=
[]
...
...
@@ -797,14 +944,16 @@ class NixlConnectorWorker:
local_block_descs_ids
,
remote_xfer_side_handle
,
remote_block_descs_ids
,
notif_msg
=
request_id
.
encode
(
"utf-8"
)
,
notif_msg
=
notif_id
,
)
# Begin async xfer.
self
.
nixl_wrapper
.
transfer
(
handle
)
# Use handle to check completion in future step().
self
.
_recving_transfers
[
request_id
].
append
(
handle
)
# TODO (NickLucche) surface xfer elapsed time
self
.
_recving_transfers
[
request_id
].
append
(
(
handle
,
time
.
perf_counter
()))
def
_get_block_descs_ids
(
self
,
engine_id
:
str
,
...
...
@@ -815,7 +964,6 @@ class NixlConnectorWorker:
If layer_idx is provided, we use the region_ids for the given layer.
Otherwise, we use all regions.
"""
if
layer_idx
is
None
:
region_ids
=
range
(
self
.
num_regions
)
else
:
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
b2fac671
...
...
@@ -16,6 +16,8 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
get_kv_connector_cache_layout
)
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
...
...
@@ -70,6 +72,20 @@ class FlashAttentionBackend(AttentionBackend):
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
    """Return the axis permutation describing the KV cache memory layout.

    The permutation maps the logical shape returned by
    `get_kv_cache_shape` to the desired in-memory ordering, based on the
    layout name reported by `get_kv_connector_cache_layout()`:

      * "NHD": identity permutation (0, 1, 2, 3, 4).
      * "HND": axes 2 and 3 are swapped, giving (0, 1, 3, 2, 4).

    Returns:
        A 5-tuple of axis indices.

    Raises:
        ValueError: If the reported layout name is neither "NHD" nor "HND".
    """
    # NOTE When running disaggregated PD with NIXL, HND layout is used for
    # faster transfer. `stride_order` indicates the permutation that gets
    # us from `get_kv_cache_shape` to the actual memory layout we want.
    cache_layout = get_kv_connector_cache_layout()
    if cache_layout == "NHD":
        stride_order = (0, 1, 2, 3, 4)
    elif cache_layout == "HND":
        stride_order = (0, 1, 3, 2, 4)
    else:
        # Exceptions do not apply %-style formatting to their arguments, so
        # interpolate the layout name into the message explicitly.
        raise ValueError(f"Unknown cache layout format {cache_layout}.")
    return stride_order
@
dataclass
class
FlashAttentionMetadata
:
...
...
vllm/worker/worker_base.py
View file @
b2fac671
...
...
@@ -597,7 +597,8 @@ class WorkerWrapperBase:
def
initialize_from_config
(
self
,
kv_cache_configs
:
List
[
Any
])
->
None
:
kv_cache_config
=
kv_cache_configs
[
self
.
rpc_rank
]
self
.
worker
.
initialize_from_config
(
kv_cache_config
)
# type: ignore
with
set_current_vllm_config
(
self
.
vllm_config
):
self
.
worker
.
initialize_from_config
(
kv_cache_config
)
# type: ignore
def
init_device
(
self
):
with
set_current_vllm_config
(
self
.
vllm_config
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment