Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc3700e0
Unverified
Commit
bc3700e0
authored
Dec 18, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Dec 18, 2025
Browse files
[NIXL] Support P tensor-parallel-size > D tensor-parallel-size (#27274)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
fd8afdf3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
556 additions
and
212 deletions
+556
-212
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
...nnector/nixl_integration/tp_config_sweep_accuracy_test.sh
+3
-0
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+223
-22
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+40
-22
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+290
-168
No files found.
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
View file @
bc3700e0
...
@@ -8,9 +8,12 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
...
@@ -8,9 +8,12 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
configs
=(
configs
=(
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA case
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
# MLA+P-TP2, D-DPEP=2 (TP=1)
)
)
run_tests
()
{
run_tests
()
{
...
...
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
bc3700e0
...
@@ -391,6 +391,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
...
@@ -391,6 +391,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_hand_shake_latency
=
hand_shake_latency
self
.
_hand_shake_latency
=
hand_shake_latency
self
.
kv_cache_layout
=
kv_cache_layout
self
.
kv_cache_layout
=
kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self
.
src_xfer_handles_by_block_size
=
{
self
.
block_size
:
1
}
def
_nixl_handshake
(
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
,
remote_tp_size
:
int
,
expected_engine_id
:
str
self
,
host
:
str
,
port
:
int
,
remote_tp_size
:
int
,
expected_engine_id
:
str
...
@@ -407,22 +409,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
...
@@ -407,22 +409,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
assert
expected_engine_id
==
self
.
REMOTE_ENGINE_ID
assert
expected_engine_id
==
self
.
REMOTE_ENGINE_ID
remote_agent_name
=
self
.
add_remote_agent
(
# Adjust remote block length metadata to satisfy heterogeneous TP
NixlAgentMetadata
(
# invariants enforced during handshake validation.
engine_id
=
self
.
REMOTE_ENGINE_ID
,
remote_block_lens
=
list
(
self
.
block_len_per_layer
)
agent_metadata
=
FakeNixlWrapper
.
AGENT_METADATA
,
tp_ratio
=
self
.
kv_topo
.
tp_ratio
(
remote_tp_size
)
kv_caches_base_addr
=
[
0
],
if
remote_tp_size
>
self
.
world_size
:
device_id
=
0
,
# P TP > D TP case, block_len of remote is smaller
num_blocks
=
1
,
remote_block_lens
=
[
block_lens
=
self
.
block_len_per_layer
,
block_len
//
(
-
tp_ratio
)
for
block_len
in
remote_block_lens
# `self.kv_cache_layout` is only forced to HND when vllm engine
]
# is started. We mock HND here.
elif
remote_tp_size
<
self
.
world_size
:
kv_cache_layout
=
"HND"
,
remote_block_lens
=
[
block_size
=
self
.
block_size
,
block_len
*
tp_ratio
for
block_len
in
remote_block_lens
),
]
remote_tp_size
=
remote_tp_size
,
)
# When remote tp_size > local tp_size, handshake with multiple
return
{
0
:
remote_agent_name
}
# remote ranks.
num_hanshakes
=
1
if
tp_ratio
>
0
else
-
tp_ratio
remote_agents
:
dict
[
int
,
str
]
=
{}
for
remote_tp_rank
in
range
(
num_hanshakes
):
remote_agent_name
=
self
.
add_remote_agent
(
NixlAgentMetadata
(
engine_id
=
self
.
REMOTE_ENGINE_ID
,
agent_metadata
=
FakeNixlWrapper
.
AGENT_METADATA
,
kv_caches_base_addr
=
[
0
],
device_id
=
remote_tp_rank
,
num_blocks
=
1
,
block_lens
=
remote_block_lens
,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout
=
"HND"
,
block_size
=
self
.
block_size
,
),
remote_tp_rank
=
remote_tp_rank
,
remote_tp_size
=
remote_tp_size
,
)
remote_agents
[
remote_tp_rank
]
=
remote_agent_name
return
remote_agents
class
TestNixlHandshake
:
class
TestNixlHandshake
:
...
@@ -453,7 +476,13 @@ class TestNixlHandshake:
...
@@ -453,7 +476,13 @@ class TestNixlHandshake:
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
)
assert
isinstance
(
connector
.
connector_worker
.
nixl_wrapper
,
FakeNixlWrapper
)
assert
isinstance
(
connector
.
connector_worker
.
nixl_wrapper
,
FakeNixlWrapper
)
connector
.
connector_worker
.
nixl_wrapper
.
set_cycles_before_xfer_done
(
3
)
worker
=
connector
.
connector_worker
worker
.
nixl_wrapper
.
set_cycles_before_xfer_done
(
3
)
# simulate handshake
worker
.
dst_xfer_side_handles
=
{
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
:
{
0
:
1
}
}
worker
.
kv_cache_layout
=
"HND"
num_xfers
=
4
num_xfers
=
4
while
True
:
while
True
:
# For the same request_id, initiate multiple xfers across different
# For the same request_id, initiate multiple xfers across different
...
@@ -567,6 +596,171 @@ class TestNixlHandshake:
...
@@ -567,6 +596,171 @@ class TestNixlHandshake:
return
return
raise
TimeoutError
(
"Took too long to complete async handshake."
)
raise
TimeoutError
(
"Took too long to complete async handshake."
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
,
)
@
pytest
.
mark
.
parametrize
(
"local_tp_size"
,
[
1
,
2
])
def
test_prefill_tp_size_greater_than_decode_tp_size
(
self
,
local_tp_size
:
int
,
dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations.
"""
vllm_config
=
create_vllm_config
()
local_tp_size
=
1
vllm_config
.
parallel_config
.
tensor_parallel_size
=
local_tp_size
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
worker
=
connector
.
connector_worker
# Minimal local registration params used by add_remote_agent
worker
.
slot_size_per_layer
=
[
4096
]
worker
.
block_len_per_layer
=
[
4096
*
worker
.
block_size
]
worker
.
num_blocks
=
1
worker
.
dst_num_blocks
[
worker
.
engine_id
]
=
worker
.
num_blocks
worker
.
src_blocks_data
=
[(
0
,
worker
.
block_len_per_layer
[
0
],
worker
.
tp_rank
)]
def
check_handshake
(
remote_tp_size
:
int
):
tp_ratio
=
remote_tp_size
//
local_tp_size
assert
set
(
remote_agents
.
keys
())
==
set
(
range
(
tp_ratio
))
remote_engine_id
=
worker
.
REMOTE_ENGINE_ID
assert
worker
.
_tp_size
[
remote_engine_id
]
==
remote_tp_size
assert
-
tp_ratio
==
worker
.
kv_topo
.
tp_ratio_from_engine_id
(
remote_engine_id
)
# ensure src_xfer_handles_by_tp_ratio is populated with tpratio chunks
assert
-
tp_ratio
in
worker
.
src_xfer_handles_by_tp_ratio
assert
len
(
worker
.
src_xfer_handles_by_tp_ratio
[
-
tp_ratio
])
==
tp_ratio
assert
remote_engine_id
in
worker
.
dst_xfer_side_handles
assert
set
(
worker
.
dst_xfer_side_handles
[
remote_engine_id
].
keys
())
==
set
(
range
(
tp_ratio
)
)
remote_agents
=
worker
.
_nixl_handshake
(
host
=
"localhost"
,
port
=
1234
,
remote_tp_size
=
2
,
expected_engine_id
=
worker
.
REMOTE_ENGINE_ID
,
)
check_handshake
(
2
)
# NOTE flexiblity: a second remote with higher number of ranks is
# discovered. This is not a scenario we actively support right now, but
# the connector allows it.
worker
.
REMOTE_ENGINE_ID
=
"remote_engine_2"
remote_agents
=
worker
.
_nixl_handshake
(
host
=
"localhost"
,
port
=
1234
,
remote_tp_size
=
6
,
expected_engine_id
=
worker
.
REMOTE_ENGINE_ID
,
)
check_handshake
(
6
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
,
)
@
pytest
.
mark
.
parametrize
(
"local_tp_size"
,
[
1
,
2
])
def
test_prefill_tp_size_greater_than_decode_tp_size_mla
(
self
,
local_tp_size
:
int
,
dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations for an MLA model.
"""
vllm_config
=
create_vllm_config
()
d_tp_size
=
1
p_tp_size
=
2
# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
conn_p1
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
conn_p0
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
conn_p0
.
engine_id
,
hand_shake_latency
=
0
)
conn_p1
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
conn_p1
.
engine_id
,
hand_shake_latency
=
0
)
# Force P world size to 2 for both workers and emulate distinct tp_ranks.
# Also enable MLA path so that expected_finished_count is updated.
for
rank
,
worker
in
enumerate
(
(
conn_p0
.
connector_worker
,
conn_p1
.
connector_worker
)
):
worker
.
world_size
=
p_tp_size
worker
.
kv_topo
.
remote_tp_size
=
{
worker
.
engine_id
:
p_tp_size
}
worker
.
tp_rank
=
rank
worker
.
use_mla
=
True
req_id
=
"req-ep-dp2-p0"
now
=
time
.
perf_counter
()
# Register a request on P that is waiting for consumers to read
# (both workers track it).
conn_p0
.
connector_worker
.
_reqs_to_send
[
req_id
]
=
now
+
10.0
conn_p0
.
connector_worker
.
_reqs_to_process
.
add
(
req_id
)
conn_p1
.
connector_worker
.
_reqs_to_send
[
req_id
]
=
now
+
10.0
conn_p1
.
connector_worker
.
_reqs_to_process
.
add
(
req_id
)
# Simulate a read notification coming from D with (tp=1, dp=2).
notif
=
f
"
{
req_id
}
:
{
d_tp_size
}
"
.
encode
()
# D0-0->P0 notif
conn_p0
.
connector_worker
.
nixl_wrapper
.
get_new_notifs
=
lambda
:
{
"agent"
:
[
notif
]
}
# type: ignore[method-assign]
conn_p1
.
connector_worker
.
nixl_wrapper
.
get_new_notifs
=
lambda
:
{
"agent"
:
[
notif
]
}
# type: ignore[method-assign]
# Trigger notification processing via get_finished().
done_sending0
,
_
=
conn_p0
.
get_finished
(
finished_req_ids
=
set
())
done_sending1
,
_
=
conn_p1
.
get_finished
(
finished_req_ids
=
set
())
assert
req_id
in
done_sending0
and
req_id
in
done_sending1
# E2E aggregation: ensure the aggregated output marks the request
# as finished using the connector's expected_finished_count.
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
aggregator
=
KVOutputAggregator
.
from_connector
(
conn_p0
,
world_size
=
2
)
out0
=
ModelRunnerOutput
(
req_ids
=
[
req_id
],
req_id_to_index
=
{
req_id
:
0
},
sampled_token_ids
=
[[
0
]],
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
[
None
],
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
done_sending0
,
finished_recving
=
None
,
),
)
out1
=
ModelRunnerOutput
(
req_ids
=
[
req_id
],
req_id_to_index
=
{
req_id
:
0
},
sampled_token_ids
=
[[
0
]],
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
[
None
],
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
done_sending1
,
finished_recving
=
None
,
),
)
aggregated
=
aggregator
.
aggregate
([
out0
,
out1
],
output_rank
=
0
)
assert
aggregated
.
kv_connector_output
is
not
None
assert
aggregated
.
kv_connector_output
.
finished_sending
==
{
req_id
}
# Producers cleaned up state for the finished request.
assert
req_id
not
in
conn_p0
.
connector_worker
.
_reqs_to_send
assert
req_id
not
in
conn_p0
.
connector_worker
.
_reqs_to_process
assert
req_id
not
in
conn_p1
.
connector_worker
.
_reqs_to_send
assert
req_id
not
in
conn_p1
.
connector_worker
.
_reqs_to_process
@
patch
(
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
,
FakeNixlWrapper
,
...
@@ -585,6 +779,9 @@ class TestNixlHandshake:
...
@@ -585,6 +779,9 @@ class TestNixlHandshake:
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
vllm_config
,
connector
.
engine_id
)
)
# Register (mocked) local xfer handler
# worker = connector.connector_worker
# worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
metadata
=
NixlConnectorMetadata
()
metadata
=
NixlConnectorMetadata
()
total_reqs
=
5
total_reqs
=
5
for
i
in
range
(
total_reqs
):
for
i
in
range
(
total_reqs
):
...
@@ -672,7 +869,6 @@ class TestNixlHandshake:
...
@@ -672,7 +869,6 @@ class TestNixlHandshake:
with
pytest
.
raises
(
RuntimeError
):
with
pytest
.
raises
(
RuntimeError
):
# mismatched layout is expected to fail
# mismatched layout is expected to fail
worker
.
add_remote_agent
(
meta
,
remote_tp_size
=
2
)
worker
.
add_remote_agent
(
meta
,
remote_tp_size
=
2
)
with
pytest
.
raises
(
AssertionError
):
worker
.
add_remote_agent
(
meta
,
remote_tp_size
=
1
)
worker
.
add_remote_agent
(
meta
,
remote_tp_size
=
1
)
@
patch
(
@
patch
(
...
@@ -1357,8 +1553,11 @@ def test_shutdown_cleans_up_resources(dist_init):
...
@@ -1357,8 +1553,11 @@ def test_shutdown_cleans_up_resources(dist_init):
patch
.
object
(
nixl_wrapper
,
"deregister_memory"
)
as
mock_dereg
,
patch
.
object
(
nixl_wrapper
,
"deregister_memory"
)
as
mock_dereg
,
):
):
worker
.
_recving_transfers
=
{
"req1"
:
[
123
]}
worker
.
_recving_transfers
=
{
"req1"
:
[
123
]}
worker
.
src_xfer_side_handle
=
456
# Mock register_kv_cache which registers local handle
worker
.
dst_xfer_side_handles
=
{
"engine1"
:
789
}
worker
.
src_xfer_handles_by_block_size
=
{
worker
.
block_size
:
455
}
# P TP = 2 * D TP case, we should register 2 local handles
worker
.
src_xfer_handles_by_tp_ratio
=
{
-
2
:
[
456
,
457
]}
worker
.
dst_xfer_side_handles
=
{
"engine1"
:
{
0
:
789
}}
worker
.
_remote_agents
=
{
"engine1"
:
{
0
:
"agent1"
}}
worker
.
_remote_agents
=
{
"engine1"
:
{
0
:
"agent1"
}}
worker
.
_registered_descs
=
[
"desc1"
,
"desc2"
]
worker
.
_registered_descs
=
[
"desc1"
,
"desc2"
]
...
@@ -1379,8 +1578,10 @@ def test_shutdown_cleans_up_resources(dist_init):
...
@@ -1379,8 +1578,10 @@ def test_shutdown_cleans_up_resources(dist_init):
mock_listener
.
join
.
assert_called_once
()
mock_listener
.
join
.
assert_called_once
()
mock_rel_xfer
.
assert_called_once_with
(
123
)
mock_rel_xfer
.
assert_called_once_with
(
123
)
assert
mock_rel_dlist
.
call_count
==
2
assert
mock_rel_dlist
.
call_count
==
4
mock_rel_dlist
.
assert_any_call
(
456
)
# src handle
mock_rel_dlist
.
assert_any_call
(
455
)
# src handle (whole region)
mock_rel_dlist
.
assert_any_call
(
456
)
# src handle (1st chunk)
mock_rel_dlist
.
assert_any_call
(
457
)
# src handle (2nd chunk)
mock_rel_dlist
.
assert_any_call
(
789
)
# dst handle
mock_rel_dlist
.
assert_any_call
(
789
)
# dst handle
mock_rem_agent
.
assert_called_once_with
(
"agent1"
)
mock_rem_agent
.
assert_called_once_with
(
"agent1"
)
assert
mock_dereg
.
call_count
==
2
assert
mock_dereg
.
call_count
==
2
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
bc3700e0
...
@@ -21,6 +21,8 @@ if TYPE_CHECKING:
...
@@ -21,6 +21,8 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
EngineId
=
str
def
get_kv_connector_cache_layout
():
def
get_kv_connector_cache_layout
():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
...
@@ -209,12 +211,12 @@ class TpKVTopology:
...
@@ -209,12 +211,12 @@ class TpKVTopology:
"""
"""
tp_rank
:
int
tp_rank
:
int
remote_tp_size
:
dict
[
str
,
int
]
remote_tp_size
:
dict
[
EngineId
,
int
]
is_mla
:
bool
is_mla
:
bool
total_num_kv_heads
:
int
total_num_kv_heads
:
int
attn_backend
:
type
[
AttentionBackend
]
attn_backend
:
type
[
AttentionBackend
]
engine_id
:
str
engine_id
:
EngineId
remote_block_size
:
dict
[
str
,
int
]
remote_block_size
:
dict
[
EngineId
,
int
]
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Figure out whether the first dimension of the cache is K/V
# Figure out whether the first dimension of the cache is K/V
...
@@ -256,18 +258,28 @@ class TpKVTopology:
...
@@ -256,18 +258,28 @@ class TpKVTopology:
Calculate the tensor parallel ratio between local and remote TP.
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
groups of size `tp_ratio`.If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
"""
"""
assert
self
.
tp_size
%
remote_tp_size
==
0
,
(
if
self
.
tp_size
>=
remote_tp_size
:
f
"Local tensor parallel size
{
self
.
tp_size
}
is not divisible "
assert
self
.
tp_size
%
remote_tp_size
==
0
,
(
f
"by remote tensor parallel size
{
remote_tp_size
}
."
f
"Local tensor parallel size
{
self
.
tp_size
}
is not divisible "
f
"by remote tensor parallel size
{
remote_tp_size
}
."
)
return
self
.
tp_size
//
remote_tp_size
assert
remote_tp_size
%
self
.
tp_size
==
0
,
(
f
"Remote tensor parallel size
{
remote_tp_size
}
is not divisible "
f
"by local tensor parallel size
{
self
.
tp_size
}
."
)
)
return
self
.
tp_size
//
remote_tp_size
# P TP > D TP case, return the ratio as negative
return
-
remote_tp_size
//
self
.
tp_size
def
block_size_ratio
(
def
block_size_ratio
(
self
,
self
,
remote_block_size
:
int
,
remote_block_size
:
int
,
)
->
floa
t
:
)
->
in
t
:
"""
"""
Calculate the block size ratio between local and remote TP.
Calculate the block size ratio between local and remote TP.
"""
"""
...
@@ -279,19 +291,19 @@ class TpKVTopology:
...
@@ -279,19 +291,19 @@ class TpKVTopology:
def
tp_ratio_from_engine_id
(
def
tp_ratio_from_engine_id
(
self
,
self
,
remote_engine_id
:
str
,
remote_engine_id
:
EngineId
,
)
->
int
:
)
->
int
:
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
return
self
.
tp_ratio
(
remote_tp_size
)
return
self
.
tp_ratio
(
remote_tp_size
)
def
block_size_ratio_from_engine_id
(
def
block_size_ratio_from_engine_id
(
self
,
self
,
remote_engine_id
:
str
,
remote_engine_id
:
EngineId
,
)
->
floa
t
:
)
->
in
t
:
remote_block_size
=
self
.
remote_block_size
[
remote_engine_id
]
remote_block_size
=
self
.
remote_block_size
[
remote_engine_id
]
return
self
.
block_size_ratio
(
remote_block_size
)
return
self
.
block_size_ratio
(
remote_block_size
)
def
is_kv_replicated
(
self
,
engine_id
:
str
)
->
bool
:
def
is_kv_replicated
(
self
,
engine_id
:
EngineId
)
->
bool
:
"""
"""
Whether the KV cache is replicated across TP workers due to the
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
number of TP workers being greater than the number of KV heads.
...
@@ -299,24 +311,30 @@ class TpKVTopology:
...
@@ -299,24 +311,30 @@ class TpKVTopology:
tp_size
=
self
.
remote_tp_size
[
engine_id
]
tp_size
=
self
.
remote_tp_size
[
engine_id
]
return
tp_size
//
self
.
total_num_kv_heads
>=
1
return
tp_size
//
self
.
total_num_kv_heads
>=
1
def
replicates_kv_cache
(
self
,
remote_engine_id
:
str
)
->
bool
:
def
replicates_kv_cache
(
self
,
remote_engine_id
:
EngineId
)
->
bool
:
# MLA is always replicated as the hidden dim can't be split.
# MLA is always replicated as the hidden dim can't be split.
return
self
.
is_mla
or
self
.
is_kv_replicated
(
remote_engine_id
)
return
self
.
is_mla
or
self
.
is_kv_replicated
(
remote_engine_id
)
def
get_target_remote_rank
(
def
get_target_remote_rank
s
(
self
,
self
,
remote_tp_size
:
int
,
remote_tp_size
:
int
,
)
->
int
:
)
->
list
[
int
]
:
"""
"""
Get the remote TP rank (on P) that the current local TP rank
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
(on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
"""
"""
tp_ratio
=
self
.
tp_ratio
(
remote_tp_size
)
tp_ratio
=
self
.
tp_ratio
(
remote_tp_size
)
return
self
.
tp_rank
//
tp_ratio
if
tp_ratio
>
0
:
return
[
self
.
tp_rank
//
tp_ratio
]
def
get_target_remote_rank_from_engine_id
(
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio
=
-
tp_ratio
return
[
self
.
tp_rank
*
tp_ratio
+
i
for
i
in
range
(
tp_ratio
)]
def
get_target_remote_ranks_from_engine_id
(
self
,
self
,
remote_engine_id
:
str
,
remote_engine_id
:
EngineId
,
)
->
int
:
)
->
list
[
int
]
:
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
return
self
.
get_target_remote_rank
(
remote_tp_size
)
return
self
.
get_target_remote_rank
s
(
remote_tp_size
)
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
bc3700e0
...
@@ -23,7 +23,7 @@ from vllm import envs
...
@@ -23,7 +23,7 @@ from vllm import envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
TpKVTopology
from
vllm.distributed.kv_transfer.kv_connector.utils
import
EngineId
,
TpKVTopology
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
CopyBlocksOp
,
CopyBlocksOp
,
KVConnectorBase_V1
,
KVConnectorBase_V1
,
...
@@ -56,7 +56,6 @@ if TYPE_CHECKING:
...
@@ -56,7 +56,6 @@ if TYPE_CHECKING:
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
TransferHandle
=
int
TransferHandle
=
int
EngineId
=
str
ReqId
=
str
ReqId
=
str
#
#
...
@@ -873,9 +872,10 @@ class NixlConnectorWorker:
...
@@ -873,9 +872,10 @@ class NixlConnectorWorker:
self
.
copy_blocks
:
CopyBlocksOp
|
None
=
None
self
.
copy_blocks
:
CopyBlocksOp
|
None
=
None
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self
.
kv_caches_base_addr
:
dict
[
EngineId
,
list
[
int
]]
=
{}
self
.
device_id
:
int
=
0
self
.
device_id
:
int
=
0
# Current rank may pull from multiple remote TP workers.
# EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer
self
.
kv_caches_base_addr
=
defaultdict
[
EngineId
,
dict
[
int
,
list
[
int
]]](
dict
)
# Number of NIXL regions. Currently one region per cache
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
# (so 1 per layer for MLA, otherwise 2 per layer)
...
@@ -883,10 +883,12 @@ class NixlConnectorWorker:
...
@@ -883,10 +883,12 @@ class NixlConnectorWorker:
self
.
num_layers
=
0
self
.
num_layers
=
0
# nixl_prepped_dlist_handle.
# nixl_prepped_dlist_handle.
self
.
src_xfer_side_handle
:
int
=
0
self
.
src_xfer_handles_by_block_size
:
dict
[
int
,
int
]
=
{}
self
.
src_xfer_side_handles
:
dict
[
int
,
int
]
=
{}
# Populated dynamically during handshake based on remote configuration.
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
# Keep track of regions at different tp_ratio values. tp_ratio->handles
self
.
dst_xfer_side_handles
:
dict
[
EngineId
,
int
]
=
{}
self
.
src_xfer_handles_by_tp_ratio
:
dict
[
int
,
list
[
int
]]
=
{}
# Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}.
self
.
dst_xfer_side_handles
=
defaultdict
[
EngineId
,
dict
[
int
,
int
]](
dict
)
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks.
# have the same number of blocks.
...
@@ -977,103 +979,108 @@ class NixlConnectorWorker:
...
@@ -977,103 +979,108 @@ class NixlConnectorWorker:
expected_engine_id
:
str
,
expected_engine_id
:
str
,
)
->
dict
[
int
,
str
]:
)
->
dict
[
int
,
str
]:
"""Do a NIXL handshake with a remote instance."""
"""Do a NIXL handshake with a remote instance."""
# When target instance TP > local TP, we need to perform multiple
start_time
=
time
.
perf_counter
()
# handshakes. Do it in a single background job for simplicity.
# Regardless, only handshake with the remote TP rank(s) that current
# NOTE(rob): we need each rank to have a unique port. This is
# local rank will read from. Note that With homogeneous TP,
# a hack to keep us moving. We will switch when moving to etcd
# this happens to be the same single rank_i.
# or where we have a single ZMQ socket in the scheduler.
p_remote_ranks
=
self
.
kv_topo
.
get_target_remote_ranks
(
remote_tp_size
)
remote_rank_to_agent_name
=
{}
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank
=
self
.
kv_topo
.
get_target_remote_rank
(
remote_tp_size
)
path
=
make_zmq_path
(
"tcp"
,
host
,
port
)
path
=
make_zmq_path
(
"tcp"
,
host
,
port
)
logger
.
debug
(
"Querying metadata on path: %s at remote tp rank %s"
,
path
,
p_remote_rank
)
# Send query for the request.
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
msg
=
msgspec
.
msgpack
.
encode
((
GET_META_MSG
,
p_remote_rank
))
for
remote_rank
in
p_remote_ranks
:
# Set receive timeout to 5 seconds to avoid hanging on dead server
logger
.
debug
(
sock
.
setsockopt
(
zmq
.
RCVTIMEO
,
5000
)
# milliseconds
"Querying metadata on path: %s at remote tp rank %s"
,
sock
.
send
(
msg
)
path
,
handshake_bytes
=
sock
.
recv
()
remote_rank
,
# Decode handshake payload to get compatibility hash
handshake_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlHandshakePayload
)
try
:
handshake_payload
=
handshake_decoder
.
decode
(
handshake_bytes
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
raise
RuntimeError
(
f
"Failed to decode NixlHandshakePayload. This likely indicates "
f
"an incompatibility between connector version. Error:
{
e
}
"
)
from
e
got_metadata_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
)
# Check compatibility hash BEFORE decoding agent metadata
if
(
self
.
enforce_compat_hash
and
handshake_payload
.
compatibility_hash
!=
self
.
compat_hash
):
raise
RuntimeError
(
f
"NIXL compatibility hash mismatch. "
f
"Local:
{
self
.
compat_hash
}
, "
f
"Remote:
{
handshake_payload
.
compatibility_hash
}
. "
f
"Prefill and decode instances have incompatible configurations. "
f
"This may be due to: different vLLM versions, models, dtypes, "
f
"KV cache layouts, attention backends, etc. "
f
"Both instances must use identical configurations."
f
"Disable this check using "
f
'--kv-transfer-config
\'
{{"kv_connector_extra_config": '
f
'{{"enforce_handshake_compat": false}}}}
\'
'
)
)
logger
.
info
(
start_time
=
time
.
perf_counter
()
"NIXL compatibility check passed (hash: %s)"
,
# Send query for the request.
handshake_payload
.
compatibility_hash
,
msg
=
msgspec
.
msgpack
.
encode
((
GET_META_MSG
,
remote_rank
))
)
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock
.
setsockopt
(
zmq
.
RCVTIMEO
,
5000
)
# milliseconds
sock
.
send
(
msg
)
handshake_bytes
=
sock
.
recv
()
# Decode agent metadata
# Decode handshake payload to get compatibility hash
metadata_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
handshake_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlHandshakePayload
)
try
:
try
:
metadata
=
metadata_decoder
.
decode
(
handshake_payload
=
handshake_decoder
.
decode
(
handshake_bytes
)
handshake_payload
.
agent_metadata_bytes
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
raise
RuntimeError
(
f
"Failed to decode NixlHandshakePayload. This likely indicates "
f
"an incompatibility between connector version. Error:
{
e
}
"
)
from
e
got_metadata_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
,
)
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
# This should not happen if hash matched
raise
RuntimeError
(
f
"Failed to decode NixlAgentMetadata. Error:
{
e
}
"
)
from
e
# Ensure engine id matches.
# Check compatibility hash BEFORE decoding agent metadata
if
metadata
.
engine_id
!=
expected_engine_id
:
if
(
raise
RuntimeError
(
self
.
enforce_compat_hash
f
"Remote NIXL agent engine ID mismatch. "
and
handshake_payload
.
compatibility_hash
!=
self
.
compat_hash
f
"Expected
{
expected_engine_id
}
,"
):
f
"received
{
metadata
.
engine_id
}
."
raise
RuntimeError
(
)
f
"NIXL compatibility hash mismatch. "
f
"Local:
{
self
.
compat_hash
}
, "
f
"Remote:
{
handshake_payload
.
compatibility_hash
}
. "
f
"Prefill and decode instances have incompatible "
f
"configurations. This may be due to: different vLLM versions,"
f
" models, dtypes, KV cache layouts, attention backends, etc. "
f
"Both instances must use identical configurations."
f
"Disable this check using "
f
'--kv-transfer-config
\'
{{"kv_connector_extra_config": '
f
'{{"enforce_handshake_compat": false}}}}
\'
'
)
# Register Remote agent.
logger
.
info
(
assert
metadata
.
block_size
<=
self
.
block_size
,
(
"NIXL compatibility check passed (hash: %s)"
,
"nP > nD is not supported yet."
handshake_payload
.
compatibility_hash
,
)
)
remote_agent_name
=
self
.
add_remote_agent
(
metadata
,
p_remote_rank
,
remote_tp_size
)
setup_agent_time
=
time
.
perf_counter
()
# Decode agent metadata
logger
.
debug
(
metadata_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
"NIXL handshake: add agent took: %s"
,
try
:
setup_agent_time
-
got_metadata_time
,
metadata
=
metadata_decoder
.
decode
(
)
handshake_payload
.
agent_metadata_bytes
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
# This should not happen if hash matched
raise
RuntimeError
(
f
"Failed to decode NixlAgentMetadata. Error:
{
e
}
"
)
from
e
# Ensure engine id matches.
if
metadata
.
engine_id
!=
expected_engine_id
:
raise
RuntimeError
(
f
"Remote NIXL agent engine ID mismatch. "
f
"Expected
{
expected_engine_id
}
,"
f
"received
{
metadata
.
engine_id
}
."
)
# Ensure engine id matches.
if
metadata
.
engine_id
!=
expected_engine_id
:
raise
RuntimeError
(
f
"Remote NIXL agent engine ID mismatch. "
f
"Expected
{
expected_engine_id
}
,"
f
"received
{
metadata
.
engine_id
}
."
)
setup_agent_time
=
time
.
perf_counter
()
# Remote rank -> agent name.
# Register Remote agent.
return
{
p_remote_rank
:
remote_agent_name
}
remote_agent_name
=
self
.
add_remote_agent
(
metadata
,
remote_rank
,
remote_tp_size
)
logger
.
debug
(
"NIXL handshake: add agent took: %s"
,
setup_agent_time
-
got_metadata_time
,
)
remote_rank_to_agent_name
[
remote_rank
]
=
remote_agent_name
return
remote_rank_to_agent_name
def
initialize_host_xfer_buffer
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
])
->
None
:
def
initialize_host_xfer_buffer
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
])
->
None
:
"""
"""
...
@@ -1283,7 +1290,7 @@ class NixlConnectorWorker:
...
@@ -1283,7 +1290,7 @@ class NixlConnectorWorker:
assert
len
(
self
.
block_len_per_layer
)
==
len
(
seen_base_addresses
)
assert
len
(
self
.
block_len_per_layer
)
==
len
(
seen_base_addresses
)
assert
self
.
num_blocks
!=
0
assert
self
.
num_blocks
!=
0
self
.
kv_caches_base_addr
[
self
.
engine_id
]
=
seen_base_addresses
self
.
kv_caches_base_addr
[
self
.
engine_id
]
[
self
.
tp_rank
]
=
seen_base_addresses
self
.
num_regions
=
len
(
caches_data
)
self
.
num_regions
=
len
(
caches_data
)
self
.
num_layers
=
len
(
xfer_buffers
.
keys
())
self
.
num_layers
=
len
(
xfer_buffers
.
keys
())
...
@@ -1310,9 +1317,9 @@ class NixlConnectorWorker:
...
@@ -1310,9 +1317,9 @@ class NixlConnectorWorker:
# Register local/src descr for NIXL xfer.
# Register local/src descr for NIXL xfer.
self
.
seen_base_addresses
=
seen_base_addresses
self
.
seen_base_addresses
=
seen_base_addresses
self
.
src_xfer_
side_
handle
=
self
.
register_local_xfer_handler
(
self
.
block_size
)
self
.
src_xfer_handle
s_by_block_size
[
self
.
block_size
],
self
.
src_blocks_data
=
(
self
.
register_local_xfer_handler
(
self
.
block_size
)
self
.
src_xfer_side_handles
[
self
.
block_size
]
=
self
.
src_xfer_side_handle
)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
# models with local attention (Llama 4). Can remove this once enabled.
...
@@ -1340,8 +1347,8 @@ class NixlConnectorWorker:
...
@@ -1340,8 +1347,8 @@ class NixlConnectorWorker:
agent_metadata
=
NixlAgentMetadata
(
agent_metadata
=
NixlAgentMetadata
(
engine_id
=
self
.
engine_id
,
engine_id
=
self
.
engine_id
,
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
device_id
=
self
.
device_id
,
device_id
=
self
.
device_id
,
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
][
self
.
tp_rank
],
num_blocks
=
self
.
num_blocks
,
num_blocks
=
self
.
num_blocks
,
block_lens
=
self
.
block_len_per_layer
,
block_lens
=
self
.
block_len_per_layer
,
kv_cache_layout
=
self
.
kv_cache_layout
kv_cache_layout
=
self
.
kv_cache_layout
...
@@ -1359,7 +1366,7 @@ class NixlConnectorWorker:
...
@@ -1359,7 +1366,7 @@ class NixlConnectorWorker:
def
register_local_xfer_handler
(
def
register_local_xfer_handler
(
self
,
self
,
block_size
:
int
,
block_size
:
int
,
)
->
int
:
)
->
tuple
[
int
,
list
[
tuple
[
int
,
int
,
int
]]]
:
"""
"""
Function used for register local xfer handler with local block_size or
Function used for register local xfer handler with local block_size or
Remote block_size.
Remote block_size.
...
@@ -1407,7 +1414,7 @@ class NixlConnectorWorker:
...
@@ -1407,7 +1414,7 @@ class NixlConnectorWorker:
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
# NIXL_INIT_AGENT to be used for preparations of local descs.
# NIXL_INIT_AGENT to be used for preparations of local descs.
return
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
return
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
,
blocks_data
def
add_remote_agent
(
def
add_remote_agent
(
self
,
self
,
...
@@ -1421,10 +1428,12 @@ class NixlConnectorWorker:
...
@@ -1421,10 +1428,12 @@ class NixlConnectorWorker:
In particular, handle both homogeneous and heterogeneous TP. The former
In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i.
requires local rank_i to read from remote rank_i.
The latter, assuming D.world_size > P.world_size, requires that two or
The latter, in the case of D.world_size < P.world_size, requires that a
more local TP worker share the xfer from a single TP worker.
local (D) TP worker reads from multiple remote (P) TP workers.
Conversely, assuming D.world_size > P.world_size, two or more local TP
workers will read from a single remote TP worker.
Here's an example (non-MLA
case
):
Here's an example
for the last case described above
(non-MLA):
rank_offset p_remote_tp_rank
rank_offset p_remote_tp_rank
(kv split no)
(kv split no)
...
@@ -1474,9 +1483,6 @@ class NixlConnectorWorker:
...
@@ -1474,9 +1483,6 @@ class NixlConnectorWorker:
nixl_agent_meta
.
agent_metadata
nixl_agent_meta
.
agent_metadata
)
)
# Handle tp_size>num_kv_heads: replicate KV cache.
replicates_kv_cache
=
self
.
kv_topo
.
replicates_kv_cache
(
engine_id
)
# Create dst descs and xfer side handles. TP workers have same #blocks
# Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id.
# so we only register once per engine_id.
# Example:
# Example:
...
@@ -1490,14 +1496,52 @@ class NixlConnectorWorker:
...
@@ -1490,14 +1496,52 @@ class NixlConnectorWorker:
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
# Keep track of remote agent kv caches base addresses.
# Keep track of remote agent kv caches base addresses.
self
.
kv_caches_base_addr
[
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
self
.
kv_caches_base_addr
[
engine_id
][
remote_tp_rank
]
=
(
nixl_agent_meta
.
kv_caches_base_addr
)
self
.
_validate_remote_agent_handshake
(
nixl_agent_meta
,
remote_tp_size
)
self
.
_validate_remote_agent_handshake
(
nixl_agent_meta
,
remote_tp_size
)
#
Number of D TP workers reading from a single P TP worker. This is
#
This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
#
1 when P and D `--tensor-parallel-size` match
.
#
this is the ratio between the two sizes
.
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
engine_id
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
engine_id
)
# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote
=
(
not
self
.
kv_topo
.
replicates_kv_cache
(
engine_id
)
and
tp_ratio
>
0
)
logger
.
debug
(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s"
,
engine_id
,
remote_tp_rank
,
tp_ratio
,
)
### (Optional) Register local agent memory regions. MLA is not split.
if
(
tp_ratio
<
0
and
not
self
.
use_mla
and
tp_ratio
not
in
self
.
src_xfer_handles_by_tp_ratio
):
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
]
=
[]
for
i
in
range
(
-
tp_ratio
):
blocks_data
=
[]
for
memory_region
in
self
.
src_blocks_data
:
addr
,
local_block_len
,
own_tp_rank
=
memory_region
# Computing block len layer by layer allows for different
# block sizes to be used.
remote_block_len
=
local_block_len
//
(
-
tp_ratio
)
addr
=
addr
+
i
*
remote_block_len
blocks_data
.
append
((
addr
,
remote_block_len
,
own_tp_rank
))
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
handle
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
].
append
(
handle
)
### Register remote agent memory regions
### Register remote agent memory regions
blocks_data
=
[]
blocks_data
=
[]
# With homogeneous TP, D pulls the whole kv cache from corresponding
# With homogeneous TP, D pulls the whole kv cache from corresponding
...
@@ -1507,14 +1551,19 @@ class NixlConnectorWorker:
...
@@ -1507,14 +1551,19 @@ class NixlConnectorWorker:
# Register all remote blocks, but only the corresponding kv heads.
# Register all remote blocks, but only the corresponding kv heads.
for
i
,
base_addr
in
enumerate
(
nixl_agent_meta
.
kv_caches_base_addr
):
for
i
,
base_addr
in
enumerate
(
nixl_agent_meta
.
kv_caches_base_addr
):
kv_block_len
=
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
)
# Read our whole local region size from remote.
remote_kv_block_len
=
kv_block_len
//
block_size_ratio
local_block_len
=
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
)
remote_kv_block_len
=
local_block_len
//
block_size_ratio
if
block_size_ratio
>
1
:
if
block_size_ratio
>
1
:
# using remote kv_block_len as transfer unit
# using remote kv_block_len as transfer unit
kv_block_len
=
remote_kv_block_len
local_block_len
=
remote_kv_block_len
if
tp_ratio
<
0
and
not
self
.
use_mla
:
# Remote tp is bigger: read a chunk of local region from remote
local_block_len
=
local_block_len
//
(
-
tp_ratio
)
rank_offset
=
(
rank_offset
=
(
self
.
tp_rank
%
tp_ratio
*
remote_kv_block_len
self
.
tp_rank
%
tp_ratio
*
remote_kv_block_len
if
not
replicates_kv_cach
e
if
indexes_into_remot
e
else
0
else
0
)
)
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
...
@@ -1524,7 +1573,7 @@ class NixlConnectorWorker:
...
@@ -1524,7 +1573,7 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes.
# self.block_len == remote_block_len//tp_ratio bytes.
addr
=
base_addr
+
block_offset
+
rank_offset
addr
=
base_addr
+
block_offset
+
rank_offset
# (addr, len, device id)
# (addr, len, device id)
blocks_data
.
append
((
addr
,
kv
_block_len
,
nixl_agent_meta
.
device_id
))
blocks_data
.
append
((
addr
,
local
_block_len
,
nixl_agent_meta
.
device_id
))
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
# With FlashInfer index V separately to allow head splitting.
# With FlashInfer index V separately to allow head splitting.
...
@@ -1533,7 +1582,7 @@ class NixlConnectorWorker:
...
@@ -1533,7 +1582,7 @@ class NixlConnectorWorker:
addr
=
base_addr
+
block_offset
+
rank_offset
addr
=
base_addr
+
block_offset
+
rank_offset
v_addr
=
addr
+
nixl_agent_meta
.
block_lens
[
i
]
//
2
v_addr
=
addr
+
nixl_agent_meta
.
block_lens
[
i
]
//
2
blocks_data
.
append
(
blocks_data
.
append
(
(
v_addr
,
kv
_block_len
,
nixl_agent_meta
.
device_id
)
(
v_addr
,
local
_block_len
,
nixl_agent_meta
.
device_id
)
)
)
logger
.
debug
(
logger
.
debug
(
...
@@ -1546,15 +1595,15 @@ class NixlConnectorWorker:
...
@@ -1546,15 +1595,15 @@ class NixlConnectorWorker:
# Register with NIXL.
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
self
.
dst_xfer_side_handles
[
engine_id
]
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
self
.
dst_xfer_side_handles
[
engine_id
]
[
remote_tp_rank
]
=
(
remote_agent_name
,
descs
self
.
nixl_wrapper
.
prep_xfer_dlist
(
remote_agent_name
,
descs
)
)
)
if
block_size_ratio
>
1
:
if
block_size_ratio
>
1
:
# when prefill with smaller block_size, we need to init a
# when prefill with smaller block_size, we need to init a
# new handler with same block_len to match
# new handler with same block_len to match
self
.
src_xfer_
side_
handles
[
nixl_agent_meta
.
block_size
]
=
(
self
.
src_xfer_handles
_by_block_size
[
nixl_agent_meta
.
block_size
]
=
(
self
.
register_local_xfer_handler
(
nixl_agent_meta
.
block_size
)
self
.
register_local_xfer_handler
(
nixl_agent_meta
.
block_size
)
[
0
]
)
)
return
remote_agent_name
return
remote_agent_name
...
@@ -1574,7 +1623,9 @@ class NixlConnectorWorker:
...
@@ -1574,7 +1623,9 @@ class NixlConnectorWorker:
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
remote_engine_id
remote_engine_id
)
)
assert
tp_ratio
>
0
,
"Decode TP cannot be smaller than prefill TP"
# Num kv_heads > tp_size and P TP > D TP case, not supported
assert
not
(
tp_ratio
<
0
and
self
.
kv_topo
.
is_kv_replicated
(
remote_engine_id
))
assert
not
self
.
_use_pallas
or
tp_ratio
==
1
,
(
assert
not
self
.
_use_pallas
or
tp_ratio
==
1
,
(
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
)
...
@@ -1616,17 +1667,29 @@ class NixlConnectorWorker:
...
@@ -1616,17 +1667,29 @@ class NixlConnectorWorker:
"All remote layers must have the same block size"
"All remote layers must have the same block size"
)
)
assert
(
if
tp_ratio
>
0
:
remote_block_len
# Remote tp is smaller: remote block_len size is bigger
==
(
self
.
block_len_per_layer
[
0
]
*
tp_ratio
)
//
block_size_ratio
assert
(
),
(
remote_block_len
"Remote P worker KV layer cache must be of shape [2, N, "
==
(
self
.
block_len_per_layer
[
0
]
*
tp_ratio
)
//
block_size_ratio
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
),
(
)
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
)
# noqa: E501
else
:
assert
block_size_ratio
==
1
,
(
"Different local/remote block sizes are not supported when"
" P TP > D TP."
)
# Remote tp is bigger: remote block_len size is smaller
assert
remote_block_len
==
self
.
block_len_per_layer
[
0
]
//
(
-
tp_ratio
),
(
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
)
# noqa: E501
# TP workers have same #blocks.
# TP workers
that handhshake with same remote
have same #blocks.
assert
self
.
dst_num_blocks
[
remote_engine_id
]
==
nixl_agent_meta
.
num_blocks
assert
self
.
dst_num_blocks
[
remote_engine_id
]
==
nixl_agent_meta
.
num_blocks
# Same number of regions/~layers.
assert
len
(
nixl_agent_meta
.
kv_caches_base_addr
)
==
len
(
self
.
block_len_per_layer
)
assert
len
(
nixl_agent_meta
.
kv_caches_base_addr
)
==
len
(
self
.
block_len_per_layer
)
def
sync_recved_kv_to_device
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
def
sync_recved_kv_to_device
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
...
@@ -1710,7 +1773,7 @@ class NixlConnectorWorker:
...
@@ -1710,7 +1773,7 @@ class NixlConnectorWorker:
)
)
cache
.
index_copy_
(
0
,
indices
,
permuted_blocks
)
cache
.
index_copy_
(
0
,
indices
,
permuted_blocks
)
def
blocksize_post_process
(
self
,
block_ids_per_ratio
:
dict
[
floa
t
,
list
[
list
[
int
]]]):
def
blocksize_post_process
(
self
,
block_ids_per_ratio
:
dict
[
in
t
,
list
[
list
[
int
]]]):
def
_process_local_gt_remote
(
blocks_to_update
,
block_size_ratio
):
def
_process_local_gt_remote
(
blocks_to_update
,
block_size_ratio
):
n_kv_heads
,
block_size
,
head_size
=
blocks_to_update
.
shape
[
1
:]
n_kv_heads
,
block_size
,
head_size
=
blocks_to_update
.
shape
[
1
:]
remote_block_size
=
block_size
//
block_size_ratio
remote_block_size
=
block_size
//
block_size_ratio
...
@@ -1840,7 +1903,7 @@ class NixlConnectorWorker:
...
@@ -1840,7 +1903,7 @@ class NixlConnectorWorker:
notified_req_ids
:
set
[
str
]
=
set
()
notified_req_ids
:
set
[
str
]
=
set
()
for
notifs
in
self
.
nixl_wrapper
.
get_new_notifs
().
values
():
for
notifs
in
self
.
nixl_wrapper
.
get_new_notifs
().
values
():
for
notif
in
notifs
:
for
notif
in
notifs
:
req_id
,
tp_
ratio
=
notif
.
decode
(
"utf-8"
).
rsplit
(
":"
,
1
)
req_id
,
tp_
size
=
notif
.
decode
(
"utf-8"
).
rsplit
(
":"
,
1
)
if
(
if
(
req_id
not
in
self
.
_reqs_to_send
req_id
not
in
self
.
_reqs_to_send
and
req_id
not
in
self
.
_reqs_to_process
and
req_id
not
in
self
.
_reqs_to_process
...
@@ -1853,9 +1916,22 @@ class NixlConnectorWorker:
...
@@ -1853,9 +1916,22 @@ class NixlConnectorWorker:
)
)
continue
continue
# NOTE: `tp_ratio` is the opposite when swapping local<>remote
n_consumers
=
int
(
tp_size
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio
(
n_consumers
)
# Number of reads *per producer* to wait for.
# When remote D TP > local P TP we expect `tp_ratio` reads.
consumers_per_producer
=
(
-
tp_ratio
if
n_consumers
>
self
.
world_size
else
1
)
self
.
consumer_notification_counts_by_req
[
req_id
]
+=
1
self
.
consumer_notification_counts_by_req
[
req_id
]
+=
1
# Wait all consumers (D) to be done reading before freeing.
# Wait all consumers (D) to be done reading before freeing.
if
self
.
consumer_notification_counts_by_req
[
req_id
]
==
int
(
tp_ratio
):
if
(
self
.
consumer_notification_counts_by_req
[
req_id
]
==
consumers_per_producer
):
notified_req_ids
.
add
(
req_id
)
notified_req_ids
.
add
(
req_id
)
del
self
.
consumer_notification_counts_by_req
[
req_id
]
del
self
.
consumer_notification_counts_by_req
[
req_id
]
self
.
_reqs_to_process
.
remove
(
req_id
)
self
.
_reqs_to_process
.
remove
(
req_id
)
...
@@ -1872,7 +1948,7 @@ class NixlConnectorWorker:
...
@@ -1872,7 +1948,7 @@ class NixlConnectorWorker:
"""
"""
done_req_ids
:
set
[
str
]
=
set
()
done_req_ids
:
set
[
str
]
=
set
()
for
req_id
,
handles
in
list
(
transfers
.
items
()):
for
req_id
,
handles
in
list
(
transfers
.
items
()):
in_progress
=
False
in_progress
=
[]
for
handle
in
handles
:
for
handle
in
handles
:
try
:
try
:
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
...
@@ -1882,7 +1958,7 @@ class NixlConnectorWorker:
...
@@ -1882,7 +1958,7 @@ class NixlConnectorWorker:
self
.
xfer_stats
.
record_transfer
(
res
)
self
.
xfer_stats
.
record_transfer
(
res
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
elif
xfer_state
==
"PROC"
:
elif
xfer_state
==
"PROC"
:
in_progress
=
True
in_progress
.
append
(
handle
)
continue
continue
else
:
else
:
logger
.
error
(
logger
.
error
(
...
@@ -1892,7 +1968,6 @@ class NixlConnectorWorker:
...
@@ -1892,7 +1968,6 @@ class NixlConnectorWorker:
xfer_state
,
xfer_state
,
)
)
self
.
_handle_failed_transfer
(
req_id
,
handle
)
self
.
_handle_failed_transfer
(
req_id
,
handle
)
in_progress
=
False
except
Exception
:
except
Exception
:
logger
.
exception
(
logger
.
exception
(
"NIXL transfer exception for request %s. "
"NIXL transfer exception for request %s. "
...
@@ -1900,11 +1975,13 @@ class NixlConnectorWorker:
...
@@ -1900,11 +1975,13 @@ class NixlConnectorWorker:
req_id
,
req_id
,
)
)
self
.
_handle_failed_transfer
(
req_id
,
handle
)
self
.
_handle_failed_transfer
(
req_id
,
handle
)
in_progress
=
False
if
not
in_progress
:
if
not
in_progress
:
# Only report request as completed when all transfers are done.
done_req_ids
.
add
(
req_id
)
done_req_ids
.
add
(
req_id
)
del
transfers
[
req_id
]
del
transfers
[
req_id
]
else
:
transfers
[
req_id
]
=
in_progress
return
done_req_ids
return
done_req_ids
def
_handle_failed_transfer
(
self
,
req_id
:
str
,
handle
:
int
):
def
_handle_failed_transfer
(
self
,
req_id
:
str
,
handle
:
int
):
...
@@ -1982,18 +2059,62 @@ class NixlConnectorWorker:
...
@@ -1982,18 +2059,62 @@ class NixlConnectorWorker:
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
assert
meta
.
remote
is
not
None
assert
meta
.
remote
is
not
None
logger
.
debug
(
remote_ranks
=
self
.
kv_topo
.
get_target_remote_ranks_from_engine_id
(
"Remote agent %s available, calling _read_blocks for req %s"
,
meta
.
remote
.
engine_id
meta
.
remote
.
engine_id
,
req_id
,
)
self
.
_read_blocks
(
request_id
=
req_id
,
dst_engine_id
=
meta
.
remote
.
engine_id
,
remote_request_id
=
meta
.
remote
.
request_id
,
local_block_ids
=
meta
.
local_physical_block_ids
,
remote_block_ids
=
meta
.
remote
.
block_ids
,
)
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
meta
.
remote
.
engine_id
)
# D may have to perform multiple reads from different remote ranks.
for
i
,
remote_rank
in
enumerate
(
remote_ranks
):
if
self
.
use_mla
and
tp_ratio
<
0
and
i
>
0
:
# MLA opt: when P TP > D TP, only a single read is executed for
# the first remote rank (cache is duplicated)..
break
remote_block_size
=
self
.
kv_topo
.
remote_block_size
[
meta
.
remote
.
engine_id
]
logger
.
debug
(
"Remote agent %s available, calling _read_blocks"
" on remote rank %s with remote block size %s for req %s"
,
meta
.
remote
.
engine_id
,
remote_rank
,
remote_block_size
,
req_id
,
)
# Get side handles.
if
tp_ratio
<
0
and
not
self
.
use_mla
:
assert
remote_block_size
==
self
.
block_size
# Remote tp_size > local tp_size: we must perform multiple
# reads. Get the memory chunk onto which we will write to.
local_xfer_side_handle
=
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
][
i
]
else
:
# Single read from remote, we write to the whole memory region.
# Also handle remote block size different from local block size.
local_xfer_side_handle
=
self
.
src_xfer_handles_by_block_size
[
remote_block_size
]
# Destination handle: remote_engine_id -> remote_rank -> handle.
remote_xfer_side_handle
=
self
.
dst_xfer_side_handles
[
meta
.
remote
.
engine_id
][
remote_rank
]
self
.
_read_blocks
(
request_id
=
req_id
,
dst_engine_id
=
meta
.
remote
.
engine_id
,
remote_request_id
=
meta
.
remote
.
request_id
,
local_block_ids
=
meta
.
local_physical_block_ids
,
remote_block_ids
=
meta
.
remote
.
block_ids
,
remote_rank
=
remote_rank
,
local_xfer_side_handle
=
local_xfer_side_handle
,
remote_xfer_side_handle
=
remote_xfer_side_handle
,
)
if
self
.
use_mla
and
tp_ratio
<
0
:
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
notif_id
=
f
"
{
req_id
}
:
{
self
.
world_size
}
"
.
encode
()
remote_agents
=
self
.
_remote_agents
[
meta
.
remote
.
engine_id
]
for
rank_to_notify
,
agent
in
remote_agents
.
items
():
if
rank_to_notify
!=
remote_rank
:
self
.
nixl_wrapper
.
send_notif
(
agent
,
notif_msg
=
notif_id
)
def
_read_blocks
(
def
_read_blocks
(
self
,
self
,
...
@@ -2002,7 +2123,14 @@ class NixlConnectorWorker:
...
@@ -2002,7 +2123,14 @@ class NixlConnectorWorker:
dst_engine_id
:
str
,
dst_engine_id
:
str
,
request_id
:
str
,
request_id
:
str
,
remote_request_id
:
str
,
remote_request_id
:
str
,
remote_rank
:
int
,
local_xfer_side_handle
:
int
,
remote_xfer_side_handle
:
int
,
):
):
"""
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
dst_engine_id
)
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
dst_engine_id
)
if
block_size_ratio
>
1
:
if
block_size_ratio
>
1
:
local_block_ids
=
self
.
get_mapped_blocks
(
local_block_ids
=
self
.
get_mapped_blocks
(
...
@@ -2031,18 +2159,14 @@ class NixlConnectorWorker:
...
@@ -2031,18 +2159,14 @@ class NixlConnectorWorker:
# saturate IB with heterogeneous TP sizes. We should remove the staging
# saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready.
# blocks until we are ready.
# Number of D TP workers that will read from dst P. Propagate
tp_rati
o
# Number of D TP workers that will read from dst P. Propagate
inf
o
# on notification so that dst worker can wait before freeing blocks.
# on notification so that dst worker can wait before freeing blocks.
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
dst_engine_id
)
notif_id
=
f
"
{
remote_request_id
}
:
{
self
.
world_size
}
"
.
encode
()
notif_id
=
f
"
{
remote_request_id
}
:
{
tp_ratio
}
"
.
encode
()
# Full prefix cache hit: do not need to read remote blocks,
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
# just notify P worker that we have the blocks we need.
num_local_blocks
=
len
(
local_block_ids
)
num_local_blocks
=
len
(
local_block_ids
)
if
num_local_blocks
==
0
:
if
num_local_blocks
==
0
:
remote_rank
=
self
.
kv_topo
.
get_target_remote_rank_from_engine_id
(
dst_engine_id
)
agent_name
=
self
.
_remote_agents
[
dst_engine_id
][
remote_rank
]
agent_name
=
self
.
_remote_agents
[
dst_engine_id
][
remote_rank
]
try
:
try
:
self
.
nixl_wrapper
.
send_notif
(
agent_name
,
notif_msg
=
notif_id
)
self
.
nixl_wrapper
.
send_notif
(
agent_name
,
notif_msg
=
notif_id
)
...
@@ -2062,13 +2186,6 @@ class NixlConnectorWorker:
...
@@ -2062,13 +2186,6 @@ class NixlConnectorWorker:
if
num_local_blocks
<
num_remote_blocks
:
if
num_local_blocks
<
num_remote_blocks
:
remote_block_ids
=
remote_block_ids
[
-
num_local_blocks
:]
remote_block_ids
=
remote_block_ids
[
-
num_local_blocks
:]
# Get side handles.
remote_block_size
=
self
.
kv_topo
.
remote_block_size
[
dst_engine_id
]
local_xfer_side_handle
=
self
.
src_xfer_side_handles
.
get
(
remote_block_size
,
self
.
src_xfer_side_handle
)
remote_xfer_side_handle
=
self
.
dst_xfer_side_handles
[
dst_engine_id
]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
# workers will issue xfers to parts of the P worker remote kv caches.
...
@@ -2230,7 +2347,7 @@ class NixlConnectorWorker:
...
@@ -2230,7 +2347,7 @@ class NixlConnectorWorker:
block_ids_np
,
self
.
_physical_blocks_per_logical_kv_block
,
block_arange
block_ids_np
,
self
.
_physical_blocks_per_logical_kv_block
,
block_arange
).
tolist
()
).
tolist
()
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
):
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
)
->
int
:
"""
"""
Get the block length for one K/V element (K and V have the same size).
Get the block length for one K/V element (K and V have the same size).
...
@@ -2276,11 +2393,16 @@ class NixlConnectorWorker:
...
@@ -2276,11 +2393,16 @@ class NixlConnectorWorker:
for
handle
in
handles
:
for
handle
in
handles
:
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
_recving_transfers
.
clear
()
self
.
_recving_transfers
.
clear
()
if
self
.
src_xfer_side_handle
:
for
handle
in
self
.
src_xfer_handles_by_block_size
.
values
():
self
.
nixl_wrapper
.
release_dlist_handle
(
self
.
src_xfer_side_handle
)
self
.
nixl_wrapper
.
release_dlist_handle
(
handle
)
self
.
src_xfer_side_handle
=
0
self
.
src_xfer_handles_by_block_size
.
clear
()
for
dst_xfer_side_handle
in
self
.
dst_xfer_side_handles
.
values
():
for
handles
in
self
.
src_xfer_handles_by_tp_ratio
.
values
():
self
.
nixl_wrapper
.
release_dlist_handle
(
dst_xfer_side_handle
)
for
handle
in
handles
:
self
.
nixl_wrapper
.
release_dlist_handle
(
handle
)
self
.
src_xfer_handles_by_tp_ratio
.
clear
()
for
dst_xfer_side_handles
in
self
.
dst_xfer_side_handles
.
values
():
for
dst_xfer_side_handle
in
dst_xfer_side_handles
.
values
():
self
.
nixl_wrapper
.
release_dlist_handle
(
dst_xfer_side_handle
)
self
.
dst_xfer_side_handles
.
clear
()
self
.
dst_xfer_side_handles
.
clear
()
for
remote_agents
in
self
.
_remote_agents
.
values
():
for
remote_agents
in
self
.
_remote_agents
.
values
():
for
agent_name
in
remote_agents
.
values
():
for
agent_name
in
remote_agents
.
values
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment