Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc3700e0
Unverified
Commit
bc3700e0
authored
Dec 18, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Dec 18, 2025
Browse files
[NIXL] Support P tensor-parallel-size > D tensor-parallel-size (#27274)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
fd8afdf3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
556 additions
and
212 deletions
+556
-212
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
...nnector/nixl_integration/tp_config_sweep_accuracy_test.sh
+3
-0
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+223
-22
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+40
-22
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+290
-168
No files found.
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
View file @
bc3700e0
...
...
@@ -8,9 +8,12 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
configs
=(
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP2, D-DPEP=2 (TP=1)
)
run_tests
()
{
...
...
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
bc3700e0
...
...
@@ -391,6 +391,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_hand_shake_latency
=
hand_shake_latency
self
.
kv_cache_layout
=
kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self
.
src_xfer_handles_by_block_size
=
{
self
.
block_size
:
1
}
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
,
remote_tp_size
:
int
,
expected_engine_id
:
str
...
...
@@ -407,22 +409,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
assert
expected_engine_id
==
self
.
REMOTE_ENGINE_ID
remote_agent_name
=
self
.
add_remote_agent
(
NixlAgentMetadata
(
engine_id
=
self
.
REMOTE_ENGINE_ID
,
agent_metadata
=
FakeNixlWrapper
.
AGENT_METADATA
,
kv_caches_base_addr
=
[
0
],
device_id
=
0
,
num_blocks
=
1
,
block_lens
=
self
.
block_len_per_layer
,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout
=
"HND"
,
block_size
=
self
.
block_size
,
),
remote_tp_size
=
remote_tp_size
,
)
return
{
0
:
remote_agent_name
}
# Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation.
remote_block_lens
=
list
(
self
.
block_len_per_layer
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio
(
remote_tp_size
)
if
remote_tp_size
>
self
.
world_size
:
# P TP > D TP case, block_len of remote is smaller
remote_block_lens
=
[
block_len
//
(
-
tp_ratio
)
for
block_len
in
remote_block_lens
]
elif
remote_tp_size
<
self
.
world_size
:
remote_block_lens
=
[
block_len
*
tp_ratio
for
block_len
in
remote_block_lens
]
# When remote tp_size > local tp_size, handshake with multiple
# remote ranks.
num_hanshakes
=
1
if
tp_ratio
>
0
else
-
tp_ratio
remote_agents
:
dict
[
int
,
str
]
=
{}
for
remote_tp_rank
in
range
(
num_hanshakes
):
remote_agent_name
=
self
.
add_remote_agent
(
NixlAgentMetadata
(
engine_id
=
self
.
REMOTE_ENGINE_ID
,
agent_metadata
=
FakeNixlWrapper
.
AGENT_METADATA
,
kv_caches_base_addr
=
[
0
],
device_id
=
remote_tp_rank
,
num_blocks
=
1
,
block_lens
=
remote_block_lens
,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout
=
"HND"
,
block_size
=
self
.
block_size
,
),
remote_tp_rank
=
remote_tp_rank
,
remote_tp_size
=
remote_tp_size
,
)
remote_agents
[
remote_tp_rank
]
=
remote_agent_name
return
remote_agents
class
TestNixlHandshake
:
...
...
@@ -453,7 +476,13 @@ class TestNixlHandshake:
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
assert
isinstance
(
connector
.
connector_worker
.
nixl_wrapper
,
FakeNixlWrapper
)
connector
.
connector_worker
.
nixl_wrapper
.
set_cycles_before_xfer_done
(
3
)
worker
=
connector
.
connector_worker
worker
.
nixl_wrapper
.
set_cycles_before_xfer_done
(
3
)
# simulate handshake
worker
.
dst_xfer_side_handles
=
{
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
:
{
0
:
1
}
}
worker
.
kv_cache_layout
=
"HND"
num_xfers
=
4
while
True
:
# For the same request_id, initiate multiple xfers across different
...
...
@@ -567,6 +596,171 @@ class TestNixlHandshake:
return
raise
TimeoutError
(
"Took too long to complete async handshake."
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
,
)
@
pytest
.
mark
.
parametrize
(
"local_tp_size"
,
[
1
,
2
])
def
test_prefill_tp_size_greater_than_decode_tp_size
(
self
,
local_tp_size
:
int
,
dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations.
"""
vllm_config
=
create_vllm_config
()
local_tp_size
=
1
vllm_config
.
parallel_config
.
tensor_parallel_size
=
local_tp_size
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
worker
=
connector
.
connector_worker
# Minimal local registration params used by add_remote_agent
worker
.
slot_size_per_layer
=
[
4096
]
worker
.
block_len_per_layer
=
[
4096
*
worker
.
block_size
]
worker
.
num_blocks
=
1
worker
.
dst_num_blocks
[
worker
.
engine_id
]
=
worker
.
num_blocks
worker
.
src_blocks_data
=
[(
0
,
worker
.
block_len_per_layer
[
0
],
worker
.
tp_rank
)]
def
check_handshake
(
remote_tp_size
:
int
):
tp_ratio
=
remote_tp_size
//
local_tp_size
assert
set
(
remote_agents
.
keys
())
==
set
(
range
(
tp_ratio
))
remote_engine_id
=
worker
.
REMOTE_ENGINE_ID
assert
worker
.
_tp_size
[
remote_engine_id
]
==
remote_tp_size
assert
-
tp_ratio
==
worker
.
kv_topo
.
tp_ratio_from_engine_id
(
remote_engine_id
)
# ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks
assert
-
tp_ratio
in
worker
.
src_xfer_handles_by_tp_ratio
assert
len
(
worker
.
src_xfer_handles_by_tp_ratio
[
-
tp_ratio
])
==
tp_ratio
assert
remote_engine_id
in
worker
.
dst_xfer_side_handles
assert
set
(
worker
.
dst_xfer_side_handles
[
remote_engine_id
].
keys
())
==
set
(
range
(
tp_ratio
)
)
remote_agents
=
worker
.
_nixl_handshake
(
host
=
"localhost"
,
port
=
1234
,
remote_tp_size
=
2
,
expected_engine_id
=
worker
.
REMOTE_ENGINE_ID
,
)
check_handshake
(
2
)
# NOTE flexiblity: a second remote with higher number of ranks is
# discovered. This is not a scenario we actively support right now, but
# the connector allows it.
worker
.
REMOTE_ENGINE_ID
=
"remote_engine_2"
remote_agents
=
worker
.
_nixl_handshake
(
host
=
"localhost"
,
port
=
1234
,
remote_tp_size
=
6
,
expected_engine_id
=
worker
.
REMOTE_ENGINE_ID
,
)
check_handshake
(
6
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
,
)
@
pytest
.
mark
.
parametrize
(
"local_tp_size"
,
[
1
,
2
])
def
test_prefill_tp_size_greater_than_decode_tp_size_mla
(
self
,
local_tp_size
:
int
,
dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations for an MLA model.
"""
vllm_config
=
create_vllm_config
()
d_tp_size
=
1
p_tp_size
=
2
# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
conn_p1
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
conn_p0
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
conn_p0
.
engine_id
,
hand_shake_latency
=
0
)
conn_p1
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
conn_p1
.
engine_id
,
hand_shake_latency
=
0
)
# Force P world size to 2 for both workers and emulate distinct tp_ranks.
# Also enable MLA path so that expected_finished_count is updated.
for
rank
,
worker
in
enumerate
(
(
conn_p0
.
connector_worker
,
conn_p1
.
connector_worker
)
):
worker
.
world_size
=
p_tp_size
worker
.
kv_topo
.
remote_tp_size
=
{
worker
.
engine_id
:
p_tp_size
}
worker
.
tp_rank
=
rank
worker
.
use_mla
=
True
req_id
=
"req-ep-dp2-p0"
now
=
time
.
perf_counter
()
# Register a request on P that is waiting for consumers to read
# (both workers track it).
conn_p0
.
connector_worker
.
_reqs_to_send
[
req_id
]
=
now
+
10.0
conn_p0
.
connector_worker
.
_reqs_to_process
.
add
(
req_id
)
conn_p1
.
connector_worker
.
_reqs_to_send
[
req_id
]
=
now
+
10.0
conn_p1
.
connector_worker
.
_reqs_to_process
.
add
(
req_id
)
# Simulate a read notification coming from D with (tp=1, dp=2).
notif
=
f
"
{
req_id
}
:
{
d_tp_size
}
"
.
encode
()
# D0-0->P0 notif
conn_p0
.
connector_worker
.
nixl_wrapper
.
get_new_notifs
=
lambda
:
{
"agent"
:
[
notif
]
}
# type: ignore[method-assign]
conn_p1
.
connector_worker
.
nixl_wrapper
.
get_new_notifs
=
lambda
:
{
"agent"
:
[
notif
]
}
# type: ignore[method-assign]
# Trigger notification processing via get_finished().
done_sending0
,
_
=
conn_p0
.
get_finished
(
finished_req_ids
=
set
())
done_sending1
,
_
=
conn_p1
.
get_finished
(
finished_req_ids
=
set
())
assert
req_id
in
done_sending0
and
req_id
in
done_sending1
# E2E aggregation: ensure the aggregated output marks the request
# as finished using the connector's expected_finished_count.
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
aggregator
=
KVOutputAggregator
.
from_connector
(
conn_p0
,
world_size
=
2
)
out0
=
ModelRunnerOutput
(
req_ids
=
[
req_id
],
req_id_to_index
=
{
req_id
:
0
},
sampled_token_ids
=
[[
0
]],
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
[
None
],
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
done_sending0
,
finished_recving
=
None
,
),
)
out1
=
ModelRunnerOutput
(
req_ids
=
[
req_id
],
req_id_to_index
=
{
req_id
:
0
},
sampled_token_ids
=
[[
0
]],
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
[
None
],
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
done_sending1
,
finished_recving
=
None
,
),
)
aggregated
=
aggregator
.
aggregate
([
out0
,
out1
],
output_rank
=
0
)
assert
aggregated
.
kv_connector_output
is
not
None
assert
aggregated
.
kv_connector_output
.
finished_sending
==
{
req_id
}
# Producers cleaned up state for the finished request.
assert
req_id
not
in
conn_p0
.
connector_worker
.
_reqs_to_send
assert
req_id
not
in
conn_p0
.
connector_worker
.
_reqs_to_process
assert
req_id
not
in
conn_p1
.
connector_worker
.
_reqs_to_send
assert
req_id
not
in
conn_p1
.
connector_worker
.
_reqs_to_process
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
,
...
...
@@ -585,6 +779,9 @@ class TestNixlHandshake:
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
)
# Register (mocked) local xfer handler
# worker = connector.connector_worker
# worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
metadata
=
NixlConnectorMetadata
()
total_reqs
=
5
for
i
in
range
(
total_reqs
):
...
...
@@ -672,7 +869,6 @@ class TestNixlHandshake:
with
pytest
.
raises
(
RuntimeError
):
# mismatched layout is expected to fail
worker
.
add_remote_agent
(
meta
,
remote_tp_size
=
2
)
with
pytest
.
raises
(
AssertionError
):
worker
.
add_remote_agent
(
meta
,
remote_tp_size
=
1
)
@
patch
(
...
...
@@ -1357,8 +1553,11 @@ def test_shutdown_cleans_up_resources(dist_init):
patch
.
object
(
nixl_wrapper
,
"deregister_memory"
)
as
mock_dereg
,
):
worker
.
_recving_transfers
=
{
"req1"
:
[
123
]}
worker
.
src_xfer_side_handle
=
456
worker
.
dst_xfer_side_handles
=
{
"engine1"
:
789
}
# Mock register_kv_cache which registers local handle
worker
.
src_xfer_handles_by_block_size
=
{
worker
.
block_size
:
455
}
# P TP = 2 * D TP case, we should register 2 local handles
worker
.
src_xfer_handles_by_tp_ratio
=
{
-
2
:
[
456
,
457
]}
worker
.
dst_xfer_side_handles
=
{
"engine1"
:
{
0
:
789
}}
worker
.
_remote_agents
=
{
"engine1"
:
{
0
:
"agent1"
}}
worker
.
_registered_descs
=
[
"desc1"
,
"desc2"
]
...
...
@@ -1379,8 +1578,10 @@ def test_shutdown_cleans_up_resources(dist_init):
mock_listener
.
join
.
assert_called_once
()
mock_rel_xfer
.
assert_called_once_with
(
123
)
assert
mock_rel_dlist
.
call_count
==
2
mock_rel_dlist
.
assert_any_call
(
456
)
# src handle
assert
mock_rel_dlist
.
call_count
==
4
mock_rel_dlist
.
assert_any_call
(
455
)
# src handle (whole region)
mock_rel_dlist
.
assert_any_call
(
456
)
# src handle (1st chunk)
mock_rel_dlist
.
assert_any_call
(
457
)
# src handle (2nd chunk)
mock_rel_dlist
.
assert_any_call
(
789
)
# dst handle
mock_rem_agent
.
assert_called_once_with
(
"agent1"
)
assert
mock_dereg
.
call_count
==
2
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
bc3700e0
...
...
@@ -21,6 +21,8 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
EngineId
=
str
def
get_kv_connector_cache_layout
():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
...
...
@@ -209,12 +211,12 @@ class TpKVTopology:
"""
tp_rank
:
int
remote_tp_size
:
dict
[
str
,
int
]
remote_tp_size
:
dict
[
EngineId
,
int
]
is_mla
:
bool
total_num_kv_heads
:
int
attn_backend
:
type
[
AttentionBackend
]
engine_id
:
str
remote_block_size
:
dict
[
str
,
int
]
engine_id
:
EngineId
remote_block_size
:
dict
[
EngineId
,
int
]
def
__post_init__
(
self
):
# Figure out whether the first dimension of the cache is K/V
...
...
@@ -256,18 +258,28 @@ class TpKVTopology:
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
groups of size `tp_ratio`.If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
"""
assert
self
.
tp_size
%
remote_tp_size
==
0
,
(
f
"Local tensor parallel size
{
self
.
tp_size
}
is not divisible "
f
"by remote tensor parallel size
{
remote_tp_size
}
."
if
self
.
tp_size
>=
remote_tp_size
:
assert
self
.
tp_size
%
remote_tp_size
==
0
,
(
f
"Local tensor parallel size
{
self
.
tp_size
}
is not divisible "
f
"by remote tensor parallel size
{
remote_tp_size
}
."
)
return
self
.
tp_size
//
remote_tp_size
assert
remote_tp_size
%
self
.
tp_size
==
0
,
(
f
"Remote tensor parallel size
{
remote_tp_size
}
is not divisible "
f
"by local tensor parallel size
{
self
.
tp_size
}
."
)
return
self
.
tp_size
//
remote_tp_size
# P TP > D TP case, return the ratio as negative
return
-
remote_tp_size
//
self
.
tp_size
def
block_size_ratio
(
self
,
remote_block_size
:
int
,
)
->
floa
t
:
)
->
in
t
:
"""
Calculate the block size ratio between local and remote TP.
"""
...
...
@@ -279,19 +291,19 @@ class TpKVTopology:
def
tp_ratio_from_engine_id
(
self
,
remote_engine_id
:
str
,
remote_engine_id
:
EngineId
,
)
->
int
:
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
return
self
.
tp_ratio
(
remote_tp_size
)
def
block_size_ratio_from_engine_id
(
self
,
remote_engine_id
:
str
,
)
->
floa
t
:
remote_engine_id
:
EngineId
,
)
->
in
t
:
remote_block_size
=
self
.
remote_block_size
[
remote_engine_id
]
return
self
.
block_size_ratio
(
remote_block_size
)
def
is_kv_replicated
(
self
,
engine_id
:
str
)
->
bool
:
def
is_kv_replicated
(
self
,
engine_id
:
EngineId
)
->
bool
:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
...
...
@@ -299,24 +311,30 @@ class TpKVTopology:
tp_size
=
self
.
remote_tp_size
[
engine_id
]
return
tp_size
//
self
.
total_num_kv_heads
>=
1
def
replicates_kv_cache
(
self
,
remote_engine_id
:
str
)
->
bool
:
def
replicates_kv_cache
(
self
,
remote_engine_id
:
EngineId
)
->
bool
:
# MLA is always replicated as the hidden dim can't be split.
return
self
.
is_mla
or
self
.
is_kv_replicated
(
remote_engine_id
)
def
get_target_remote_rank
(
def
get_target_remote_rank
s
(
self
,
remote_tp_size
:
int
,
)
->
int
:
)
->
list
[
int
]
:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
(on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
"""
tp_ratio
=
self
.
tp_ratio
(
remote_tp_size
)
return
self
.
tp_rank
//
tp_ratio
if
tp_ratio
>
0
:
return
[
self
.
tp_rank
//
tp_ratio
]
def
get_target_remote_rank_from_engine_id
(
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio
=
-
tp_ratio
return
[
self
.
tp_rank
*
tp_ratio
+
i
for
i
in
range
(
tp_ratio
)]
def
get_target_remote_ranks_from_engine_id
(
self
,
remote_engine_id
:
str
,
)
->
int
:
remote_engine_id
:
EngineId
,
)
->
list
[
int
]
:
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
return
self
.
get_target_remote_rank
(
remote_tp_size
)
return
self
.
get_target_remote_rank
s
(
remote_tp_size
)
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
bc3700e0
...
...
@@ -23,7 +23,7 @@ from vllm import envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
TpKVTopology
from
vllm.distributed.kv_transfer.kv_connector.utils
import
EngineId
,
TpKVTopology
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
CopyBlocksOp
,
KVConnectorBase_V1
,
...
...
@@ -56,7 +56,6 @@ if TYPE_CHECKING:
from
vllm.v1.request
import
Request
TransferHandle
=
int
EngineId
=
str
ReqId
=
str
#
...
...
@@ -873,9 +872,10 @@ class NixlConnectorWorker:
self
.
copy_blocks
:
CopyBlocksOp
|
None
=
None
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self
.
kv_caches_base_addr
:
dict
[
EngineId
,
list
[
int
]]
=
{}
self
.
device_id
:
int
=
0
# Current rank may pull from multiple remote TP workers.
# EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer
self
.
kv_caches_base_addr
=
defaultdict
[
EngineId
,
dict
[
int
,
list
[
int
]]](
dict
)
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
...
...
@@ -883,10 +883,12 @@ class NixlConnectorWorker:
self
.
num_layers
=
0
# nixl_prepped_dlist_handle.
self
.
src_xfer_side_handle
:
int
=
0
self
.
src_xfer_side_handles
:
dict
[
int
,
int
]
=
{}
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
self
.
dst_xfer_side_handles
:
dict
[
EngineId
,
int
]
=
{}
self
.
src_xfer_handles_by_block_size
:
dict
[
int
,
int
]
=
{}
# Populated dynamically during handshake based on remote configuration.
# Keep track of regions at different tp_ratio values. tp_ratio->handles
self
.
src_xfer_handles_by_tp_ratio
:
dict
[
int
,
list
[
int
]]
=
{}
# Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}.
self
.
dst_xfer_side_handles
=
defaultdict
[
EngineId
,
dict
[
int
,
int
]](
dict
)
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks.
...
...
@@ -977,103 +979,108 @@ class NixlConnectorWorker:
expected_engine_id
:
str
,
)
->
dict
[
int
,
str
]:
"""Do a NIXL handshake with a remote instance."""
start_time
=
time
.
perf_counter
()
# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank
=
self
.
kv_topo
.
get_target_remote_rank
(
remote_tp_size
)
# When target instance TP > local TP, we need to perform multiple
# handshakes. Do it in a single background job for simplicity.
# Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that With homogeneous TP,
# this happens to be the same single rank_i.
p_remote_ranks
=
self
.
kv_topo
.
get_target_remote_ranks
(
remote_tp_size
)
remote_rank_to_agent_name
=
{}
path
=
make_zmq_path
(
"tcp"
,
host
,
port
)
logger
.
debug
(
"Querying metadata on path: %s at remote tp rank %s"
,
path
,
p_remote_rank
)
# Send query for the request.
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
msg
=
msgspec
.
msgpack
.
encode
((
GET_META_MSG
,
p_remote_rank
))
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock
.
setsockopt
(
zmq
.
RCVTIMEO
,
5000
)
# milliseconds
sock
.
send
(
msg
)
handshake_bytes
=
sock
.
recv
()
# Decode handshake payload to get compatibility hash
handshake_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlHandshakePayload
)
try
:
handshake_payload
=
handshake_decoder
.
decode
(
handshake_bytes
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
raise
RuntimeError
(
f
"Failed to decode NixlHandshakePayload. This likely indicates "
f
"an incompatibility between connector version. Error:
{
e
}
"
)
from
e
got_metadata_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
)
# Check compatibility hash BEFORE decoding agent metadata
if
(
self
.
enforce_compat_hash
and
handshake_payload
.
compatibility_hash
!=
self
.
compat_hash
):
raise
RuntimeError
(
f
"NIXL compatibility hash mismatch. "
f
"Local:
{
self
.
compat_hash
}
, "
f
"Remote:
{
handshake_payload
.
compatibility_hash
}
. "
f
"Prefill and decode instances have incompatible configurations. "
f
"This may be due to: different vLLM versions, models, dtypes, "
f
"KV cache layouts, attention backends, etc. "
f
"Both instances must use identical configurations."
f
"Disable this check using "
f
'--kv-transfer-config
\'
{{"kv_connector_extra_config": '
f
'{{"enforce_handshake_compat": false}}}}
\'
'
for
remote_rank
in
p_remote_ranks
:
logger
.
debug
(
"Querying metadata on path: %s at remote tp rank %s"
,
path
,
remote_rank
,
)
logger
.
info
(
"NIXL compatibility check passed (hash: %s)"
,
handshake_payload
.
compatibility_hash
,
)
start_time
=
time
.
perf_counter
()
# Send query for the request.
msg
=
msgspec
.
msgpack
.
encode
((
GET_META_MSG
,
remote_rank
))
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock
.
setsockopt
(
zmq
.
RCVTIMEO
,
5000
)
# milliseconds
sock
.
send
(
msg
)
handshake_bytes
=
sock
.
recv
()
# Decode agent metadata
metadata_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
try
:
metadata
=
metadata_decoder
.
decode
(
handshake_payload
.
agent_metadata_bytes
# Decode handshake payload to get compatibility hash
handshake_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlHandshakePayload
)
try
:
handshake_payload
=
handshake_decoder
.
decode
(
handshake_bytes
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
raise
RuntimeError
(
f
"Failed to decode NixlHandshakePayload. This likely indicates "
f
"an incompatibility between connector version. Error:
{
e
}
"
)
from
e
got_metadata_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
,
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
# This should not happen if hash matched
raise
RuntimeError
(
f
"Failed to decode NixlAgentMetadata. Error:
{
e
}
"
)
from
e
# Ensure engine id matches.
if
metadata
.
engine_id
!=
expected_engine_id
:
raise
RuntimeError
(
f
"Remote NIXL agent engine ID mismatch. "
f
"Expected
{
expected_engine_id
}
,"
f
"received
{
metadata
.
engine_id
}
."
)
# Check compatibility hash BEFORE decoding agent metadata
if
(
self
.
enforce_compat_hash
and
handshake_payload
.
compatibility_hash
!=
self
.
compat_hash
):
raise
RuntimeError
(
f
"NIXL compatibility hash mismatch. "
f
"Local:
{
self
.
compat_hash
}
, "
f
"Remote:
{
handshake_payload
.
compatibility_hash
}
. "
f
"Prefill and decode instances have incompatible "
f
"configurations. This may be due to: different vLLM versions,"
f
" models, dtypes, KV cache layouts, attention backends, etc. "
f
"Both instances must use identical configurations."
f
"Disable this check using "
f
'--kv-transfer-config
\'
{{"kv_connector_extra_config": '
f
'{{"enforce_handshake_compat": false}}}}
\'
'
)
# Register Remote agent.
assert
metadata
.
block_size
<=
self
.
block_size
,
(
"nP > nD is not supported yet."
)
remote_agent_name
=
self
.
add_remote_agent
(
metadata
,
p_remote_rank
,
remote_tp_size
)
logger
.
info
(
"NIXL compatibility check passed (hash: %s)"
,
handshake_payload
.
compatibility_hash
,
)
setup_agent_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: add agent took: %s"
,
setup_agent_time
-
got_metadata_time
,
)
# Decode agent metadata
metadata_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
try
:
metadata
=
metadata_decoder
.
decode
(
handshake_payload
.
agent_metadata_bytes
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
# This should not happen if hash matched
raise
RuntimeError
(
f
"Failed to decode NixlAgentMetadata. Error:
{
e
}
"
)
from
e
# Ensure engine id matches.
if
metadata
.
engine_id
!=
expected_engine_id
:
raise
RuntimeError
(
f
"Remote NIXL agent engine ID mismatch. "
f
"Expected
{
expected_engine_id
}
,"
f
"received
{
metadata
.
engine_id
}
."
)
# Ensure engine id matches.
if
metadata
.
engine_id
!=
expected_engine_id
:
raise
RuntimeError
(
f
"Remote NIXL agent engine ID mismatch. "
f
"Expected
{
expected_engine_id
}
,"
f
"received
{
metadata
.
engine_id
}
."
)
setup_agent_time
=
time
.
perf_counter
()
# Remote rank -> agent name.
return
{
p_remote_rank
:
remote_agent_name
}
# Register Remote agent.
remote_agent_name
=
self
.
add_remote_agent
(
metadata
,
remote_rank
,
remote_tp_size
)
logger
.
debug
(
"NIXL handshake: add agent took: %s"
,
setup_agent_time
-
got_metadata_time
,
)
remote_rank_to_agent_name
[
remote_rank
]
=
remote_agent_name
return
remote_rank_to_agent_name
def
initialize_host_xfer_buffer
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
])
->
None
:
"""
...
...
@@ -1283,7 +1290,7 @@ class NixlConnectorWorker:
assert
len
(
self
.
block_len_per_layer
)
==
len
(
seen_base_addresses
)
assert
self
.
num_blocks
!=
0
self
.
kv_caches_base_addr
[
self
.
engine_id
]
=
seen_base_addresses
self
.
kv_caches_base_addr
[
self
.
engine_id
]
[
self
.
tp_rank
]
=
seen_base_addresses
self
.
num_regions
=
len
(
caches_data
)
self
.
num_layers
=
len
(
xfer_buffers
.
keys
())
...
...
@@ -1310,9 +1317,9 @@ class NixlConnectorWorker:
# Register local/src descr for NIXL xfer.
self
.
seen_base_addresses
=
seen_base_addresses
self
.
src_xfer_
side_
handle
=
self
.
register_local_xfer_handler
(
self
.
block_size
)
self
.
src_xfer_side_handles
[
self
.
block_size
]
=
self
.
src_xfer_side_handle
self
.
src_xfer_handle
s_by_block_size
[
self
.
block_size
],
self
.
src_blocks_data
=
(
self
.
register_local_xfer_handler
(
self
.
block_size
)
)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
...
...
@@ -1340,8 +1347,8 @@ class NixlConnectorWorker:
agent_metadata
=
NixlAgentMetadata
(
engine_id
=
self
.
engine_id
,
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
device_id
=
self
.
device_id
,
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
][
self
.
tp_rank
],
num_blocks
=
self
.
num_blocks
,
block_lens
=
self
.
block_len_per_layer
,
kv_cache_layout
=
self
.
kv_cache_layout
...
...
@@ -1359,7 +1366,7 @@ class NixlConnectorWorker:
def
register_local_xfer_handler
(
self
,
block_size
:
int
,
)
->
int
:
)
->
tuple
[
int
,
list
[
tuple
[
int
,
int
,
int
]]]
:
"""
Function used for register local xfer handler with local block_size or
Remote block_size.
...
...
@@ -1407,7 +1414,7 @@ class NixlConnectorWorker:
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
# NIXL_INIT_AGENT to be used for preparations of local descs.
return
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
return
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
,
blocks_data
def
add_remote_agent
(
self
,
...
...
@@ -1421,10 +1428,12 @@ class NixlConnectorWorker:
In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i.
The latter, assuming D.world_size > P.world_size, requires that two or
more local TP worker share the xfer from a single TP worker.
The latter, in the case of D.world_size < P.world_size, requires that a
local (D) TP worker reads from multiple remote (P) TP workers.
Conversely, assuming D.world_size > P.world_size, two or more local TP
workers will read from a single remote TP worker.
Here's an example (non-MLA
case
):
Here's an example
for the last case described above
(non-MLA):
rank_offset p_remote_tp_rank
(kv split no)
...
...
@@ -1474,9 +1483,6 @@ class NixlConnectorWorker:
nixl_agent_meta
.
agent_metadata
)
# Handle tp_size>num_kv_heads: replicate KV cache.
replicates_kv_cache
=
self
.
kv_topo
.
replicates_kv_cache
(
engine_id
)
# Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id.
# Example:
...
...
@@ -1490,14 +1496,52 @@ class NixlConnectorWorker:
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
# Keep track of remote agent kv caches base addresses.
self
.
kv_caches_base_addr
[
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
self
.
kv_caches_base_addr
[
engine_id
][
remote_tp_rank
]
=
(
nixl_agent_meta
.
kv_caches_base_addr
)
self
.
_validate_remote_agent_handshake
(
nixl_agent_meta
,
remote_tp_size
)
#
Number of D TP workers reading from a single P TP worker. This is
#
1 when P and D `--tensor-parallel-size` match
.
#
This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
#
this is the ratio between the two sizes
.
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
engine_id
)
# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote
=
(
not
self
.
kv_topo
.
replicates_kv_cache
(
engine_id
)
and
tp_ratio
>
0
)
logger
.
debug
(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s"
,
engine_id
,
remote_tp_rank
,
tp_ratio
,
)
### (Optional) Register local agent memory regions. MLA is not split.
if
(
tp_ratio
<
0
and
not
self
.
use_mla
and
tp_ratio
not
in
self
.
src_xfer_handles_by_tp_ratio
):
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
]
=
[]
for
i
in
range
(
-
tp_ratio
):
blocks_data
=
[]
for
memory_region
in
self
.
src_blocks_data
:
addr
,
local_block_len
,
own_tp_rank
=
memory_region
# Computing block len layer by layer allows for different
# block sizes to be used.
remote_block_len
=
local_block_len
//
(
-
tp_ratio
)
addr
=
addr
+
i
*
remote_block_len
blocks_data
.
append
((
addr
,
remote_block_len
,
own_tp_rank
))
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
handle
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
].
append
(
handle
)
### Register remote agent memory regions
blocks_data
=
[]
# With homogeneous TP, D pulls the whole kv cache from corresponding
...
...
@@ -1507,14 +1551,19 @@ class NixlConnectorWorker:
# Register all remote blocks, but only the corresponding kv heads.
for
i
,
base_addr
in
enumerate
(
nixl_agent_meta
.
kv_caches_base_addr
):
kv_block_len
=
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
)
remote_kv_block_len
=
kv_block_len
//
block_size_ratio
# Read our whole local region size from remote.
local_block_len
=
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
)
remote_kv_block_len
=
local_block_len
//
block_size_ratio
if
block_size_ratio
>
1
:
# using remote kv_block_len as transfer unit
kv_block_len
=
remote_kv_block_len
local_block_len
=
remote_kv_block_len
if
tp_ratio
<
0
and
not
self
.
use_mla
:
# Remote tp is bigger: read a chunk of local region from remote
local_block_len
=
local_block_len
//
(
-
tp_ratio
)
rank_offset
=
(
self
.
tp_rank
%
tp_ratio
*
remote_kv_block_len
if
not
replicates_kv_cach
e
if
indexes_into_remot
e
else
0
)
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
...
...
@@ -1524,7 +1573,7 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes.
addr
=
base_addr
+
block_offset
+
rank_offset
# (addr, len, device id)
blocks_data
.
append
((
addr
,
kv
_block_len
,
nixl_agent_meta
.
device_id
))
blocks_data
.
append
((
addr
,
local
_block_len
,
nixl_agent_meta
.
device_id
))
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
# With FlashInfer index V separately to allow head splitting.
...
...
@@ -1533,7 +1582,7 @@ class NixlConnectorWorker:
addr
=
base_addr
+
block_offset
+
rank_offset
v_addr
=
addr
+
nixl_agent_meta
.
block_lens
[
i
]
//
2
blocks_data
.
append
(
(
v_addr
,
kv
_block_len
,
nixl_agent_meta
.
device_id
)
(
v_addr
,
local
_block_len
,
nixl_agent_meta
.
device_id
)
)
logger
.
debug
(
...
...
@@ -1546,15 +1595,15 @@ class NixlConnectorWorker:
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
self
.
dst_xfer_side_handles
[
engine_id
]
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
remote_agent_name
,
descs
self
.
dst_xfer_side_handles
[
engine_id
]
[
remote_tp_rank
]
=
(
self
.
nixl_wrapper
.
prep_xfer_dlist
(
remote_agent_name
,
descs
)
)
if
block_size_ratio
>
1
:
# when prefill with smaller block_size, we need to init a
# new handler with same block_len to match
self
.
src_xfer_
side_
handles
[
nixl_agent_meta
.
block_size
]
=
(
self
.
register_local_xfer_handler
(
nixl_agent_meta
.
block_size
)
self
.
src_xfer_handles
_by_block_size
[
nixl_agent_meta
.
block_size
]
=
(
self
.
register_local_xfer_handler
(
nixl_agent_meta
.
block_size
)
[
0
]
)
return
remote_agent_name
...
...
@@ -1574,7 +1623,9 @@ class NixlConnectorWorker:
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
remote_engine_id
)
assert
tp_ratio
>
0
,
"Decode TP cannot be smaller than prefill TP"
# Num kv_heads > tp_size and P TP > D TP case, not supported
assert
not
(
tp_ratio
<
0
and
self
.
kv_topo
.
is_kv_replicated
(
remote_engine_id
))
assert
not
self
.
_use_pallas
or
tp_ratio
==
1
,
(
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
...
...
@@ -1616,17 +1667,29 @@ class NixlConnectorWorker:
"All remote layers must have the same block size"
)
assert
(
remote_block_len
==
(
self
.
block_len_per_layer
[
0
]
*
tp_ratio
)
//
block_size_ratio
),
(
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
if
tp_ratio
>
0
:
# Remote tp is smaller: remote block_len size is bigger
assert
(
remote_block_len
==
(
self
.
block_len_per_layer
[
0
]
*
tp_ratio
)
//
block_size_ratio
),
(
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
)
# noqa: E501
else
:
assert
block_size_ratio
==
1
,
(
"Different local/remote block sizes are not supported when"
" P TP > D TP."
)
# Remote tp is bigger: remote block_len size is smaller
assert
remote_block_len
==
self
.
block_len_per_layer
[
0
]
//
(
-
tp_ratio
),
(
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
)
# noqa: E501
# TP workers have same #blocks.
# TP workers
that handhshake with same remote
have same #blocks.
assert
self
.
dst_num_blocks
[
remote_engine_id
]
==
nixl_agent_meta
.
num_blocks
# Same number of regions/~layers.
assert
len
(
nixl_agent_meta
.
kv_caches_base_addr
)
==
len
(
self
.
block_len_per_layer
)
def
sync_recved_kv_to_device
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
...
...
@@ -1710,7 +1773,7 @@ class NixlConnectorWorker:
)
cache
.
index_copy_
(
0
,
indices
,
permuted_blocks
)
def
blocksize_post_process
(
self
,
block_ids_per_ratio
:
dict
[
floa
t
,
list
[
list
[
int
]]]):
def
blocksize_post_process
(
self
,
block_ids_per_ratio
:
dict
[
in
t
,
list
[
list
[
int
]]]):
def
_process_local_gt_remote
(
blocks_to_update
,
block_size_ratio
):
n_kv_heads
,
block_size
,
head_size
=
blocks_to_update
.
shape
[
1
:]
remote_block_size
=
block_size
//
block_size_ratio
...
...
@@ -1840,7 +1903,7 @@ class NixlConnectorWorker:
notified_req_ids
:
set
[
str
]
=
set
()
for
notifs
in
self
.
nixl_wrapper
.
get_new_notifs
().
values
():
for
notif
in
notifs
:
req_id
,
tp_
ratio
=
notif
.
decode
(
"utf-8"
).
rsplit
(
":"
,
1
)
req_id
,
tp_
size
=
notif
.
decode
(
"utf-8"
).
rsplit
(
":"
,
1
)
if
(
req_id
not
in
self
.
_reqs_to_send
and
req_id
not
in
self
.
_reqs_to_process
...
...
@@ -1853,9 +1916,22 @@ class NixlConnectorWorker:
)
continue
# NOTE: `tp_ratio` is the opposite when swapping local<>remote
n_consumers
=
int
(
tp_size
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio
(
n_consumers
)
# Number of reads *per producer* to wait for.
# When remote D TP > local P TP we expect `tp_ratio` reads.
consumers_per_producer
=
(
-
tp_ratio
if
n_consumers
>
self
.
world_size
else
1
)
self
.
consumer_notification_counts_by_req
[
req_id
]
+=
1
# Wait all consumers (D) to be done reading before freeing.
if
self
.
consumer_notification_counts_by_req
[
req_id
]
==
int
(
tp_ratio
):
if
(
self
.
consumer_notification_counts_by_req
[
req_id
]
==
consumers_per_producer
):
notified_req_ids
.
add
(
req_id
)
del
self
.
consumer_notification_counts_by_req
[
req_id
]
self
.
_reqs_to_process
.
remove
(
req_id
)
...
...
@@ -1872,7 +1948,7 @@ class NixlConnectorWorker:
"""
done_req_ids
:
set
[
str
]
=
set
()
for
req_id
,
handles
in
list
(
transfers
.
items
()):
in_progress
=
False
in_progress
=
[]
for
handle
in
handles
:
try
:
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
...
...
@@ -1882,7 +1958,7 @@ class NixlConnectorWorker:
self
.
xfer_stats
.
record_transfer
(
res
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
elif
xfer_state
==
"PROC"
:
in_progress
=
True
in_progress
.
append
(
handle
)
continue
else
:
logger
.
error
(
...
...
@@ -1892,7 +1968,6 @@ class NixlConnectorWorker:
xfer_state
,
)
self
.
_handle_failed_transfer
(
req_id
,
handle
)
in_progress
=
False
except
Exception
:
logger
.
exception
(
"NIXL transfer exception for request %s. "
...
...
@@ -1900,11 +1975,13 @@ class NixlConnectorWorker:
req_id
,
)
self
.
_handle_failed_transfer
(
req_id
,
handle
)
in_progress
=
False
if
not
in_progress
:
# Only report request as completed when all transfers are done.
done_req_ids
.
add
(
req_id
)
del
transfers
[
req_id
]
else
:
transfers
[
req_id
]
=
in_progress
return
done_req_ids
def
_handle_failed_transfer
(
self
,
req_id
:
str
,
handle
:
int
):
...
...
@@ -1982,18 +2059,62 @@ class NixlConnectorWorker:
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
assert
meta
.
remote
is
not
None
logger
.
debug
(
"Remote agent %s available, calling _read_blocks for req %s"
,
meta
.
remote
.
engine_id
,
req_id
,
)
self
.
_read_blocks
(
request_id
=
req_id
,
dst_engine_id
=
meta
.
remote
.
engine_id
,
remote_request_id
=
meta
.
remote
.
request_id
,
local_block_ids
=
meta
.
local_physical_block_ids
,
remote_block_ids
=
meta
.
remote
.
block_ids
,
remote_ranks
=
self
.
kv_topo
.
get_target_remote_ranks_from_engine_id
(
meta
.
remote
.
engine_id
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
meta
.
remote
.
engine_id
)
# D may have to perform multiple reads from different remote ranks.
for
i
,
remote_rank
in
enumerate
(
remote_ranks
):
if
self
.
use_mla
and
tp_ratio
<
0
and
i
>
0
:
# MLA opt: when P TP > D TP, only a single read is executed for
# the first remote rank (cache is duplicated)..
break
remote_block_size
=
self
.
kv_topo
.
remote_block_size
[
meta
.
remote
.
engine_id
]
logger
.
debug
(
"Remote agent %s available, calling _read_blocks"
" on remote rank %s with remote block size %s for req %s"
,
meta
.
remote
.
engine_id
,
remote_rank
,
remote_block_size
,
req_id
,
)
# Get side handles.
if
tp_ratio
<
0
and
not
self
.
use_mla
:
assert
remote_block_size
==
self
.
block_size
# Remote tp_size > local tp_size: we must perform multiple
# reads. Get the memory chunk onto which we will write to.
local_xfer_side_handle
=
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
][
i
]
else
:
# Single read from remote, we write to the whole memory region.
# Also handle remote block size different from local block size.
local_xfer_side_handle
=
self
.
src_xfer_handles_by_block_size
[
remote_block_size
]
# Destination handle: remote_engine_id -> remote_rank -> handle.
remote_xfer_side_handle
=
self
.
dst_xfer_side_handles
[
meta
.
remote
.
engine_id
][
remote_rank
]
self
.
_read_blocks
(
request_id
=
req_id
,
dst_engine_id
=
meta
.
remote
.
engine_id
,
remote_request_id
=
meta
.
remote
.
request_id
,
local_block_ids
=
meta
.
local_physical_block_ids
,
remote_block_ids
=
meta
.
remote
.
block_ids
,
remote_rank
=
remote_rank
,
local_xfer_side_handle
=
local_xfer_side_handle
,
remote_xfer_side_handle
=
remote_xfer_side_handle
,
)
if
self
.
use_mla
and
tp_ratio
<
0
:
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
notif_id
=
f
"
{
req_id
}
:
{
self
.
world_size
}
"
.
encode
()
remote_agents
=
self
.
_remote_agents
[
meta
.
remote
.
engine_id
]
for
rank_to_notify
,
agent
in
remote_agents
.
items
():
if
rank_to_notify
!=
remote_rank
:
self
.
nixl_wrapper
.
send_notif
(
agent
,
notif_msg
=
notif_id
)
def
_read_blocks
(
self
,
...
...
@@ -2002,7 +2123,14 @@ class NixlConnectorWorker:
dst_engine_id
:
str
,
request_id
:
str
,
remote_request_id
:
str
,
remote_rank
:
int
,
local_xfer_side_handle
:
int
,
remote_xfer_side_handle
:
int
,
):
"""
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
dst_engine_id
)
if
block_size_ratio
>
1
:
local_block_ids
=
self
.
get_mapped_blocks
(
...
...
@@ -2031,18 +2159,14 @@ class NixlConnectorWorker:
# saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready.
# Number of D TP workers that will read from dst P. Propagate
tp_rati
o
# Number of D TP workers that will read from dst P. Propagate
inf
o
# on notification so that dst worker can wait before freeing blocks.
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
dst_engine_id
)
notif_id
=
f
"
{
remote_request_id
}
:
{
tp_ratio
}
"
.
encode
()
notif_id
=
f
"
{
remote_request_id
}
:
{
self
.
world_size
}
"
.
encode
()
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
num_local_blocks
=
len
(
local_block_ids
)
if
num_local_blocks
==
0
:
remote_rank
=
self
.
kv_topo
.
get_target_remote_rank_from_engine_id
(
dst_engine_id
)
agent_name
=
self
.
_remote_agents
[
dst_engine_id
][
remote_rank
]
try
:
self
.
nixl_wrapper
.
send_notif
(
agent_name
,
notif_msg
=
notif_id
)
...
...
@@ -2062,13 +2186,6 @@ class NixlConnectorWorker:
if
num_local_blocks
<
num_remote_blocks
:
remote_block_ids
=
remote_block_ids
[
-
num_local_blocks
:]
# Get side handles.
remote_block_size
=
self
.
kv_topo
.
remote_block_size
[
dst_engine_id
]
local_xfer_side_handle
=
self
.
src_xfer_side_handles
.
get
(
remote_block_size
,
self
.
src_xfer_side_handle
)
remote_xfer_side_handle
=
self
.
dst_xfer_side_handles
[
dst_engine_id
]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
...
...
@@ -2230,7 +2347,7 @@ class NixlConnectorWorker:
block_ids_np
,
self
.
_physical_blocks_per_logical_kv_block
,
block_arange
).
tolist
()
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
):
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
)
->
int
:
"""
Get the block length for one K/V element (K and V have the same size).
...
...
@@ -2276,11 +2393,16 @@ class NixlConnectorWorker:
for
handle
in
handles
:
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
_recving_transfers
.
clear
()
if
self
.
src_xfer_side_handle
:
self
.
nixl_wrapper
.
release_dlist_handle
(
self
.
src_xfer_side_handle
)
self
.
src_xfer_side_handle
=
0
for
dst_xfer_side_handle
in
self
.
dst_xfer_side_handles
.
values
():
self
.
nixl_wrapper
.
release_dlist_handle
(
dst_xfer_side_handle
)
for
handle
in
self
.
src_xfer_handles_by_block_size
.
values
():
self
.
nixl_wrapper
.
release_dlist_handle
(
handle
)
self
.
src_xfer_handles_by_block_size
.
clear
()
for
handles
in
self
.
src_xfer_handles_by_tp_ratio
.
values
():
for
handle
in
handles
:
self
.
nixl_wrapper
.
release_dlist_handle
(
handle
)
self
.
src_xfer_handles_by_tp_ratio
.
clear
()
for
dst_xfer_side_handles
in
self
.
dst_xfer_side_handles
.
values
():
for
dst_xfer_side_handle
in
dst_xfer_side_handles
.
values
():
self
.
nixl_wrapper
.
release_dlist_handle
(
dst_xfer_side_handle
)
self
.
dst_xfer_side_handles
.
clear
()
for
remote_agents
in
self
.
_remote_agents
.
values
():
for
agent_name
in
remote_agents
.
values
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment