Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
949a6a19
Unverified
Commit
949a6a19
authored
Dec 05, 2025
by
Mark McLoughlin
Committed by
GitHub
Dec 05, 2025
Browse files
[NIXL] Add compatibility checking to NIXL KV connector handshake (#29503)
Signed-off-by:
Mark McLoughlin
<
markmc@redhat.com
>
parent
2c174420
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
380 additions
and
26 deletions
+380
-26
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+213
-11
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+8
-2
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+159
-13
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
949a6a19
...
@@ -9,8 +9,10 @@ import textwrap
...
@@ -9,8 +9,10 @@ import textwrap
import
time
import
time
import
uuid
import
uuid
from
collections
import
defaultdict
from
collections
import
defaultdict
from
unittest.mock
import
patch
from
typing
import
Any
from
unittest.mock
import
MagicMock
,
patch
import
msgspec
import
pytest
import
pytest
import
ray
import
ray
import
torch
import
torch
...
@@ -18,6 +20,7 @@ import torch
...
@@ -18,6 +20,7 @@ import torch
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.config
import
KVTransferConfig
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.distributed.kv_transfer.kv_connector.v1
import
nixl_connector
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
KVConnectorStats
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
KVConnectorStats
from
vllm.distributed.kv_transfer.kv_connector.v1.multi_connector
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.multi_connector
import
(
MultiKVConnectorStats
,
MultiKVConnectorStats
,
...
@@ -29,7 +32,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
...
@@ -29,7 +32,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorMetadata
,
NixlConnectorMetadata
,
NixlConnectorScheduler
,
NixlConnectorScheduler
,
NixlConnectorWorker
,
NixlConnectorWorker
,
NixlHandshakePayload
,
NixlKVConnectorStats
,
NixlKVConnectorStats
,
compute_nixl_compatibility_hash
,
)
)
from
vllm.distributed.kv_transfer.kv_transfer_state
import
(
from
vllm.distributed.kv_transfer.kv_transfer_state
import
(
ensure_kv_transfer_shutdown
,
ensure_kv_transfer_shutdown
,
...
@@ -317,13 +322,19 @@ def test_kv_transfer_handshake(dist_init):
...
@@ -317,13 +322,19 @@ def test_kv_transfer_handshake(dist_init):
}
}
prefill_connector
.
register_kv_caches
(
kv_caches
)
prefill_connector
.
register_kv_caches
(
kv_caches
)
# Simulate EngineCore initialization that would
# Simulate EngineCore initialization that would gather connector
# gather connector metadata from all workers, the scheduler connector
# metadata from all workers
# expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
metadata
=
prefill_connector
.
get_handshake_metadata
()
# where the first key is the dp_rank, the second key is the tp_rank.
metadata
=
{
0
:
prefill_connector
.
get_handshake_metadata
()}
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
expected_agent_metadata
=
decoder
.
decode
(
metadata
.
agent_metadata_bytes
)
# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector
=
scheduler
.
get_kv_connector
()
scheduler_connector
=
scheduler
.
get_kv_connector
()
scheduler_connector
.
set_xfer_handshake_metadata
(
metadata
)
scheduler_connector
.
set_xfer_handshake_metadata
(
{
0
:
metadata
}
)
# Simulate a request that finishes prefill, which returns
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
# corresponding NixlConnectorMetadata for decode instance.
...
@@ -362,9 +373,9 @@ def test_kv_transfer_handshake(dist_init):
...
@@ -362,9 +373,9 @@ def test_kv_transfer_handshake(dist_init):
)
)
received_metadata
=
mock_add_remote_agent
.
call_args
.
args
received_metadata
=
mock_add_remote_agent
.
call_args
.
args
assert
received_metadata
[
0
]
==
expected_agent_metadata
assert
received_metadata
[
1
]
==
0
# remote_tp_rank
assert
received_metadata
[
1
]
==
0
# remote_tp_rank
assert
received_metadata
[
2
]
==
1
# remote_tp_size
assert
received_metadata
[
2
]
==
1
# remote_tp_size
assert
metadata
[
0
]
==
received_metadata
[
0
]
# Need to shutdown the background thread to release NIXL side channel port
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector
.
shutdown
()
scheduler_connector
.
shutdown
()
...
@@ -403,7 +414,6 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
...
@@ -403,7 +414,6 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
device_id
=
0
,
device_id
=
0
,
num_blocks
=
1
,
num_blocks
=
1
,
block_lens
=
self
.
block_len_per_layer
,
block_lens
=
self
.
block_len_per_layer
,
attn_backend_name
=
self
.
backend_name
,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
# is started. We mock HND here.
kv_cache_layout
=
"HND"
,
kv_cache_layout
=
"HND"
,
...
@@ -651,7 +661,6 @@ class TestNixlHandshake:
...
@@ -651,7 +661,6 @@ class TestNixlHandshake:
device_id
=
0
,
device_id
=
0
,
num_blocks
=
1
,
num_blocks
=
1
,
block_lens
=
worker
.
block_len_per_layer
,
block_lens
=
worker
.
block_len_per_layer
,
attn_backend_name
=
worker
.
backend_name
,
kv_cache_layout
=
mismatched_layout
,
kv_cache_layout
=
mismatched_layout
,
block_size
=
worker
.
block_size
,
block_size
=
worker
.
block_size
,
)
)
...
@@ -706,7 +715,6 @@ class TestNixlHandshake:
...
@@ -706,7 +715,6 @@ class TestNixlHandshake:
num_blocks
=
1
,
num_blocks
=
1
,
# prefill TP=1, decode TP=2, remote block_lens is double to local
# prefill TP=1, decode TP=2, remote block_lens is double to local
block_lens
=
[
i
*
2
for
i
in
worker
.
block_len_per_layer
],
block_lens
=
[
i
*
2
for
i
in
worker
.
block_len_per_layer
],
attn_backend_name
=
worker
.
backend_name
,
kv_cache_layout
=
"HND"
,
kv_cache_layout
=
"HND"
,
block_size
=
worker
.
block_size
,
block_size
=
worker
.
block_size
,
)
)
...
@@ -1168,6 +1176,9 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
...
@@ -1168,6 +1176,9 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
mock_wrapper_instance
=
mock_nixl_wrapper
.
return_value
mock_wrapper_instance
=
mock_nixl_wrapper
.
return_value
connector
.
connector_worker
.
nixl_wrapper
=
mock_wrapper_instance
connector
.
connector_worker
.
nixl_wrapper
=
mock_wrapper_instance
# Appease NixlHandshakePayload encoding with some bytes
mock_wrapper_instance
.
get_agent_metadata
.
return_value
=
b
"fake_agent_metadata"
# Reassure the shutdown() check that the thread is terminated
# Reassure the shutdown() check that the thread is terminated
mock_thread
.
return_value
.
is_alive
.
return_value
=
False
mock_thread
.
return_value
.
is_alive
.
return_value
=
False
...
@@ -1534,3 +1545,194 @@ def test_transfer_setup_failure_returns_finished(dist_init):
...
@@ -1534,3 +1545,194 @@ def test_transfer_setup_failure_returns_finished(dist_init):
# ensure request appears in get_finished
# ensure request appears in get_finished
_
,
done_recving
=
connector
.
get_finished
(
finished_req_ids
=
set
())
_
,
done_recving
=
connector
.
get_finished
(
finished_req_ids
=
set
())
assert
request_id
in
done_recving
assert
request_id
in
done_recving
@pytest.mark.parametrize(
    "mismatch_type,config_overrides,version_override,should_fail,enforce_handshake_compat",
    [
        ("vllm_version", {}, {"vllm_version": "0.6.1"}, True, True),
        ("nixl_connector_version", {}, {"connector_version": 37}, True, True),
        ("model_name", {"model": "facebook/opt-350m"}, {}, True, True),
        ("dtype", {"dtype": "bfloat16"}, {}, True, True),
        ("cache_dtype", {"cache_dtype": "fp8"}, {}, True, True),
        ("num_kv_heads", {"hf_overrides": {"num_key_value_heads": 8}}, {}, True, True),
        (
            "num_hidden_layers",
            {"hf_overrides": {"num_hidden_layers": 24}},
            {},
            True,
            True,
        ),
        ("hidden_size", {"hf_overrides": {"hidden_size": 1536}}, {}, True, True),
        ("block_size", {"block_size": 8}, {}, False, True),
        ("matching_config", {}, {}, False, True),
        ("escape_hatch", {"model": "facebook/opt-350m"}, {}, False, False),
    ],
)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_compatibility_hash_validation(
    dist_init,
    mismatch_type,
    config_overrides,
    version_override,
    should_fail,
    enforce_handshake_compat,
):
    """
    Check that `_nixl_handshake` honours the NIXL compatibility hash.

    A local (decode) worker is built from a fixed config, a remote (prefill)
    hash is computed from a possibly different config/version, and the
    handshake is driven through a mocked ZMQ socket.

    Parameters:
        mismatch_type: human-readable label for the parametrized case
        config_overrides: config kwargs applied only to the remote instance
        version_override: either {"vllm_version": ...} or
            {"connector_version": ...} to patch while hashing the remote side
        should_fail: True when the handshake must raise a hash-mismatch error
        enforce_handshake_compat: value of the "enforce_handshake_compat"
            extra-config switch on the local instance
    """
    # Local (decode) side: always the baseline configuration.
    decode_worker = NixlConnector(
        create_vllm_config(
            model="facebook/opt-125m",
            block_size=16,
            kv_connector_extra_config={
                "enforce_handshake_compat": enforce_handshake_compat
            },
        ),
        KVConnectorRole.WORKER,
    ).connector_worker

    # Remote (prefill) side: baseline plus whatever this case overrides.
    remote_config_params: dict[str, Any] = {
        "model": "facebook/opt-125m",
        "block_size": 16,
    } | config_overrides
    remote_vllm_config = create_vllm_config(**remote_config_params)

    # Pick the single version patch (if any) that must be active while the
    # remote hash is computed; otherwise use a no-op context.
    if "vllm_version" in version_override:
        version_patch = patch("vllm.__version__", version_override["vllm_version"])
    elif "connector_version" in version_override:
        version_patch = patch.object(
            nixl_connector,
            "NIXL_CONNECTOR_VERSION",
            version_override["connector_version"],
        )
    else:
        version_patch = contextlib.nullcontext()

    with version_patch:
        remote_hash = compute_nixl_compatibility_hash(
            remote_vllm_config, decode_worker.backend_name
        )

    remote_block_size = config_overrides.get("block_size", 16)
    encoded_agent_metadata = msgspec.msgpack.encode(
        NixlAgentMetadata(
            engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            agent_metadata=FakeNixlWrapper.AGENT_METADATA,
            kv_caches_base_addr=[0],
            device_id=0,
            num_blocks=1,
            block_lens=[4096 * remote_block_size],  # slot_size * block_size
            kv_cache_layout="HND",
            block_size=remote_block_size,
        )
    )
    wire_payload = msgspec.msgpack.encode(
        NixlHandshakePayload(
            compatibility_hash=remote_hash,
            agent_metadata_bytes=encoded_agent_metadata,
        )
    )

    # The fake socket hands the encoded payload back to the handshake code.
    fake_socket = MagicMock()
    fake_socket.recv.return_value = wire_payload

    handshake_args = {
        "host": "localhost",
        "port": 1234,
        "remote_tp_size": 1,
        "expected_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
    }

    # add_remote_agent is stubbed to avoid real NIXL work; zmq_ctx is patched
    # so that entering it yields the fake socket.
    with (
        patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
        patch.object(nixl_connector, "zmq_ctx") as zmq_ctx_mock,
    ):
        zmq_ctx_mock.return_value.__enter__.return_value = fake_socket

        if not should_fail:
            agents = decode_worker._nixl_handshake(**handshake_args)
            # A successful handshake returns a one-entry agent mapping.
            assert isinstance(agents, dict)
            assert len(agents) == 1
        else:
            with pytest.raises(RuntimeError, match="compatibility hash mismatch"):
                decode_worker._nixl_handshake(**handshake_args)
@pytest.mark.parametrize(
    "error_scenario",
    [
        "handshake_decode_error",
        "handshake_validation_error",
        "metadata_decode_error",
        "metadata_validation_error",
    ],
)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_handshake_decode_errors(dist_init, error_scenario):
    """
    Verify that malformed handshake bytes surface as RuntimeError.

    Four kinds of bad input are fed through a mocked socket:
    - bytes that are not msgpack at all (outer NixlHandshakePayload)
    - valid msgpack with the wrong shape (outer NixlHandshakePayload)
    - a valid outer payload whose inner bytes are not msgpack
    - a valid outer payload whose inner msgpack has the wrong shape
    """
    decode_worker = NixlConnector(
        create_vllm_config(
            model="facebook/opt-125m",
            block_size=16,
        ),
        KVConnectorRole.WORKER,
    ).connector_worker

    def wrap_agent_bytes(agent_bytes):
        # Outer payload is valid and carries the matching local hash, so any
        # failure must come from decoding the inner agent metadata.
        return msgspec.msgpack.encode(
            NixlHandshakePayload(
                compatibility_hash=decode_worker.compat_hash,
                agent_metadata_bytes=agent_bytes,
            )
        )

    # Builders are lazy so only the selected scenario's bytes are produced.
    payload_builders = {
        "handshake_decode_error": lambda: b"this is not valid msgpack data",
        "handshake_validation_error": lambda: msgspec.msgpack.encode(
            {"wrong_field": "value"}
        ),
        "metadata_decode_error": lambda: wrap_agent_bytes(
            b"invalid msgpack for metadata"
        ),
        "metadata_validation_error": lambda: wrap_agent_bytes(
            msgspec.msgpack.encode({"missing": "fields"})
        ),
    }
    build_payload = payload_builders.get(error_scenario)
    if build_payload is None:
        raise AssertionError(f"{error_scenario} not a valid scenario")
    msg_bytes = build_payload()

    fake_socket = MagicMock()
    fake_socket.recv.return_value = msg_bytes

    with (
        patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
        patch.object(nixl_connector, "zmq_ctx") as zmq_ctx_mock,
    ):
        zmq_ctx_mock.return_value.__enter__.return_value = fake_socket

        with pytest.raises(RuntimeError):
            decode_worker._nixl_handshake(
                host="localhost",
                port=1234,
                remote_tp_size=1,
                expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            )
tests/v1/kv_connector/unit/utils.py
View file @
949a6a19
...
@@ -90,13 +90,18 @@ def create_vllm_config(
...
@@ -90,13 +90,18 @@ def create_vllm_config(
max_model_len
:
int
=
10000
,
max_model_len
:
int
=
10000
,
enable_chunked_prefill
:
bool
=
True
,
enable_chunked_prefill
:
bool
=
True
,
enable_permute_local_kv
:
bool
=
False
,
enable_permute_local_kv
:
bool
=
False
,
kv_connector_extra_config
:
dict
[
str
,
Any
]
|
None
=
None
,
dtype
:
str
=
"float16"
,
cache_dtype
:
str
=
"auto"
,
hf_overrides
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
VllmConfig
:
)
->
VllmConfig
:
"""Initialize VllmConfig For Testing."""
"""Initialize VllmConfig For Testing."""
model_config
=
ModelConfig
(
model_config
=
ModelConfig
(
model
=
model
,
model
=
model
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
dtype
=
"float16"
,
dtype
=
dtype
,
seed
=
42
,
seed
=
42
,
hf_overrides
=
hf_overrides
or
{},
)
)
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
max_num_seqs
,
max_num_seqs
=
max_num_seqs
,
...
@@ -110,13 +115,14 @@ def create_vllm_config(
...
@@ -110,13 +115,14 @@ def create_vllm_config(
block_size
=
block_size
,
block_size
=
block_size
,
gpu_memory_utilization
=
0.9
,
gpu_memory_utilization
=
0.9
,
swap_space
=
0
,
swap_space
=
0
,
cache_dtype
=
"auto"
,
cache_dtype
=
cache_dtype
,
enable_prefix_caching
=
True
,
enable_prefix_caching
=
True
,
)
)
kv_transfer_config
=
KVTransferConfig
(
kv_transfer_config
=
KVTransferConfig
(
kv_connector
=
"NixlConnector"
,
kv_connector
=
"NixlConnector"
,
kv_role
=
"kv_both"
,
kv_role
=
"kv_both"
,
enable_permute_local_kv
=
enable_permute_local_kv
,
enable_permute_local_kv
=
enable_permute_local_kv
,
kv_connector_extra_config
=
kv_connector_extra_config
or
{},
)
)
return
VllmConfig
(
return
VllmConfig
(
scheduler_config
=
scheduler_config
,
scheduler_config
=
scheduler_config
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
949a6a19
...
@@ -59,6 +59,21 @@ Transfer = tuple[int, float] # (xfer_handle, start_time)
...
@@ -59,6 +59,21 @@ Transfer = tuple[int, float] # (xfer_handle, start_time)
EngineId
=
str
EngineId
=
str
ReqId
=
str
ReqId
=
str
#
# NIXL Connector Version
#
# Increment this version whenever there is an incompatible change to:
# - NixlAgentMetadata schema
# - kv_transfer_params schema or semantics
# - NIXL transfer protocol or wire format
# - KV cache memory layout or block organization
# - Any other change that breaks P/D interoperability
#
# Version History:
# 1: Initial version with compatibility checking
#
NIXL_CONNECTOR_VERSION
:
int
=
1
GET_META_MSG
=
b
"get_meta_msg"
GET_META_MSG
=
b
"get_meta_msg"
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -97,18 +112,95 @@ _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
...
@@ -97,18 +112,95 @@ _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
@
dataclass
@
dataclass
class
NixlAgentMetadata
(
KVConnectorHandshakeMetadata
)
:
class
NixlAgentMetadata
:
engine_id
:
str
engine_id
:
str
agent_metadata
:
bytes
agent_metadata
:
bytes
kv_caches_base_addr
:
list
[
int
]
kv_caches_base_addr
:
list
[
int
]
device_id
:
int
device_id
:
int
num_blocks
:
int
num_blocks
:
int
block_lens
:
list
[
int
]
block_lens
:
list
[
int
]
attn_backend_name
:
str
kv_cache_layout
:
str
kv_cache_layout
:
str
block_size
:
int
block_size
:
int
@dataclass
class NixlHandshakePayload(KVConnectorHandshakeMetadata):
    """
    Envelope for the NIXL handshake as it travels over the wire.

    The envelope keeps the compatibility hash separate from the (still
    encoded) agent metadata so the receiver can decode in two phases:

    1. Decode this envelope and read ``compatibility_hash``.
    2. Compare it against the locally computed hash.
    3. Decode ``agent_metadata_bytes`` only when the hashes agree.

    Because the inner NixlAgentMetadata is not decoded until after the hash
    comparison, a peer with an incompatible metadata schema is reported via
    a clear mismatch error rather than an opaque decoder failure.
    """

    # Hash of the sender's compatibility-relevant configuration.
    compatibility_hash: str
    # The sender's NixlAgentMetadata, already msgpack-encoded.
    agent_metadata_bytes: bytes
def compute_nixl_compatibility_hash(
    vllm_config: VllmConfig, attn_backend_name: str
) -> str:
    """
    Build the hash two NIXL peers compare during the handshake.

    Only settings that determine whether KV cache data can be exchanged
    are hashed. The factors, in the order they are collected:

    - ``vllm_version`` and ``NIXL_CONNECTOR_VERSION``
    - model name, model dtype, total KV heads, head size and total number
      of hidden layers (all read from ``vllm_config.model_config``)
    - the attention backend name passed in by the caller
    - the KV cache dtype (from ``vllm_config.cache_config``)

    Settings such as tensor_parallel_size, block_size and kv_cache_layout
    are deliberately left out; per the original author they are validated
    at runtime in _validate_remote_agent_handshake so that heterogeneous
    deployments remain possible.

    The factor set is expected to change over time as the check is made
    more or less permissive.

    Args:
        vllm_config: configuration whose model and cache settings are hashed
        attn_backend_name: name of the attention backend in use

    Returns:
        The string produced by ``hash_factors`` for the collected factors
        (described by the original author as a SHA-256 hex digest).
    """
    # Imported at call time so the version is looked up on every call.
    from vllm import __version__ as vllm_version
    from vllm.config.utils import hash_factors

    model_cfg = vllm_config.model_config
    cache_cfg = vllm_config.cache_config

    hashed_inputs = dict(
        # Version compatibility
        vllm_version=vllm_version,
        nixl_connector_version=NIXL_CONNECTOR_VERSION,
        # Model architecture - affects KV cache shape
        model=model_cfg.model,
        dtype=str(model_cfg.dtype),
        num_kv_heads=model_cfg.get_total_num_kv_heads(),
        head_size=model_cfg.get_head_size(),
        num_hidden_layers=model_cfg.get_total_num_hidden_layers(),
        # Attention backend and KV cache dtype affect memory layout
        attn_backend_name=attn_backend_name,
        cache_dtype=str(cache_cfg.cache_dtype),
    )

    digest = hash_factors(hashed_inputs)

    logger.info(
        "NIXL compatibility hash: %s (model=%s, dtype=%s, num_kv_heads=%d, "
        "cache_dtype=%s, attn_backend=%s)",
        digest,
        hashed_inputs["model"],
        hashed_inputs["dtype"],
        hashed_inputs["num_kv_heads"],
        hashed_inputs["cache_dtype"],
        attn_backend_name,
    )
    return digest
@
dataclass
@
dataclass
class
ReqMeta
:
class
ReqMeta
:
local_block_ids
:
list
[
int
]
local_block_ids
:
list
[
int
]
...
@@ -396,14 +488,14 @@ class NixlConnectorScheduler:
...
@@ -396,14 +488,14 @@ class NixlConnectorScheduler:
encoded_data
:
dict
[
int
,
bytes
]
=
{}
encoded_data
:
dict
[
int
,
bytes
]
=
{}
encoder
=
msgspec
.
msgpack
.
Encoder
()
encoder
=
msgspec
.
msgpack
.
Encoder
()
for
tp_rank
,
rank_metadata
in
metadata
.
items
():
for
tp_rank
,
rank_metadata
in
metadata
.
items
():
if
not
isinstance
(
rank_metadata
,
Nixl
AgentMetadata
):
if
not
isinstance
(
rank_metadata
,
Nixl
HandshakePayload
):
raise
ValueError
(
raise
ValueError
(
"NixlConnectorScheduler expects Nixl
AgentMetadata
for "
"NixlConnectorScheduler expects Nixl
HandshakePayload
for "
"handshake metadata."
"handshake metadata."
)
)
encoded_data
[
tp_rank
]
=
encoder
.
encode
(
rank_metadata
)
encoded_data
[
tp_rank
]
=
encoder
.
encode
(
rank_metadata
)
logger
.
debug
(
logger
.
debug
(
"Tp rank %d: encoded Nixl
AgentMetadata
size: %s bytes"
,
"Tp rank %d: encoded Nixl
HandshakePayload
size: %s bytes"
,
tp_rank
,
tp_rank
,
str
(
len
(
encoded_data
[
tp_rank
])),
str
(
len
(
encoded_data
[
tp_rank
])),
)
)
...
@@ -794,7 +886,7 @@ class NixlConnectorWorker:
...
@@ -794,7 +886,7 @@ class NixlConnectorWorker:
self
.
_failed_recv_reqs
:
set
[
ReqId
]
=
set
()
self
.
_failed_recv_reqs
:
set
[
ReqId
]
=
set
()
# Handshake metadata of this worker for NIXL transfers.
# Handshake metadata of this worker for NIXL transfers.
self
.
xfer_handshake_metadata
:
Nixl
AgentMetadata
|
None
=
None
self
.
xfer_handshake_metadata
:
Nixl
HandshakePayload
|
None
=
None
# Background thread for initializing new NIXL handshakes.
# Background thread for initializing new NIXL handshakes.
self
.
_handshake_initiation_executor
=
ThreadPoolExecutor
(
self
.
_handshake_initiation_executor
=
ThreadPoolExecutor
(
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
...
@@ -829,6 +921,13 @@ class NixlConnectorWorker:
...
@@ -829,6 +921,13 @@ class NixlConnectorWorker:
logger
.
debug
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
debug
(
"Detected attention backend %s"
,
self
.
backend_name
)
logger
.
debug
(
"Detected kv cache layout %s"
,
self
.
kv_cache_layout
)
logger
.
debug
(
"Detected kv cache layout %s"
,
self
.
kv_cache_layout
)
self
.
compat_hash
=
compute_nixl_compatibility_hash
(
self
.
vllm_config
,
self
.
backend_name
)
self
.
enforce_compat_hash
=
self
.
kv_transfer_config
.
get_from_extra_config
(
"enforce_handshake_compat"
,
True
)
self
.
_tp_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
self
.
_tp_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
self
.
_block_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
block_size
}
self
.
_block_size
:
dict
[
EngineId
,
int
]
=
{
self
.
engine_id
:
self
.
block_size
}
# With heterogeneous TP, P must wait for all assigned D TP workers to
# With heterogeneous TP, P must wait for all assigned D TP workers to
...
@@ -877,14 +976,58 @@ class NixlConnectorWorker:
...
@@ -877,14 +976,58 @@ class NixlConnectorWorker:
# Set receive timeout to 5 seconds to avoid hanging on dead server
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock
.
setsockopt
(
zmq
.
RCVTIMEO
,
5000
)
# milliseconds
sock
.
setsockopt
(
zmq
.
RCVTIMEO
,
5000
)
# milliseconds
sock
.
send
(
msg
)
sock
.
send
(
msg
)
metadata_bytes
=
sock
.
recv
()
handshake_bytes
=
sock
.
recv
()
decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
metadata
=
decoder
.
decode
(
metadata_bytes
)
# Decode handshake payload to get compatibility hash
handshake_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlHandshakePayload
)
try
:
handshake_payload
=
handshake_decoder
.
decode
(
handshake_bytes
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
raise
RuntimeError
(
f
"Failed to decode NixlHandshakePayload. This likely indicates "
f
"an incompatibility between connector version. Error:
{
e
}
"
)
from
e
got_metadata_time
=
time
.
perf_counter
()
got_metadata_time
=
time
.
perf_counter
()
logger
.
debug
(
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
)
)
# Check compatibility hash BEFORE decoding agent metadata
if
(
self
.
enforce_compat_hash
and
handshake_payload
.
compatibility_hash
!=
self
.
compat_hash
):
raise
RuntimeError
(
f
"NIXL compatibility hash mismatch. "
f
"Local:
{
self
.
compat_hash
}
, "
f
"Remote:
{
handshake_payload
.
compatibility_hash
}
. "
f
"Prefill and decode instances have incompatible configurations. "
f
"This may be due to: different vLLM versions, models, dtypes, "
f
"KV cache layouts, attention backends, etc. "
f
"Both instances must use identical configurations."
f
"Disable this check using "
f
'--kv-transfer-config
\'
{{"kv_connector_extra_config": '
f
'{{"enforce_handshake_compat": false}}}}
\'
'
)
logger
.
info
(
"NIXL compatibility check passed (hash: %s)"
,
handshake_payload
.
compatibility_hash
,
)
# Decode agent metadata
metadata_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
try
:
metadata
=
metadata_decoder
.
decode
(
handshake_payload
.
agent_metadata_bytes
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
# This should not happen if hash matched
raise
RuntimeError
(
f
"Failed to decode NixlAgentMetadata. Error:
{
e
}
"
)
from
e
# Ensure engine id matches.
# Ensure engine id matches.
if
metadata
.
engine_id
!=
expected_engine_id
:
if
metadata
.
engine_id
!=
expected_engine_id
:
raise
RuntimeError
(
raise
RuntimeError
(
...
@@ -1175,19 +1318,24 @@ class NixlConnectorWorker:
...
@@ -1175,19 +1318,24 @@ class NixlConnectorWorker:
assert
len
(
self
.
block_window_per_layer
)
==
self
.
num_layers
assert
len
(
self
.
block_window_per_layer
)
==
self
.
num_layers
# After KV Caches registered, listen for new connections.
# After KV Caches registered, listen for new connections.
self
.
xfer_handshake
_metadata
=
NixlAgentMetadata
(
agent
_metadata
=
NixlAgentMetadata
(
engine_id
=
self
.
engine_id
,
engine_id
=
self
.
engine_id
,
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
device_id
=
self
.
device_id
,
device_id
=
self
.
device_id
,
num_blocks
=
self
.
num_blocks
,
num_blocks
=
self
.
num_blocks
,
block_lens
=
self
.
block_len_per_layer
,
block_lens
=
self
.
block_len_per_layer
,
attn_backend_name
=
self
.
backend_name
,
kv_cache_layout
=
self
.
kv_cache_layout
kv_cache_layout
=
self
.
kv_cache_layout
if
not
self
.
use_host_buffer
if
not
self
.
use_host_buffer
else
self
.
host_buffer_kv_cache_layout
,
else
self
.
host_buffer_kv_cache_layout
,
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
)
)
# Wrap metadata in payload with hash for defensive decoding
encoder
=
msgspec
.
msgpack
.
Encoder
()
self
.
xfer_handshake_metadata
=
NixlHandshakePayload
(
compatibility_hash
=
self
.
compat_hash
,
agent_metadata_bytes
=
encoder
.
encode
(
agent_metadata
),
)
def
register_local_xfer_handler
(
def
register_local_xfer_handler
(
self
,
self
,
...
@@ -1402,8 +1550,6 @@ class NixlConnectorWorker:
...
@@ -1402,8 +1550,6 @@ class NixlConnectorWorker:
remote_engine_id
=
nixl_agent_meta
.
engine_id
remote_engine_id
=
nixl_agent_meta
.
engine_id
assert
self
.
_tp_size
[
remote_engine_id
]
==
remote_tp_size
assert
self
.
_tp_size
[
remote_engine_id
]
==
remote_tp_size
# TODO We may eventually want to skip enforcing the same attn backend.
assert
nixl_agent_meta
.
attn_backend_name
==
self
.
backend_name
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
remote_engine_id
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
remote_engine_id
)
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment