Unverified Commit 3244a2eb authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[KVConnector][NIXL] Organize NIXL connector into its own directory (#39354)



The number of features supported by the connector has grown substantially,
and the `nixl_connector.py` file has accumulated a lot of code. This change creates
a separate directory and isolates the connector/scheduler code, in the hope of
improving clarity and maintainability.

Further refactor of components aimed at improving clarity and simplifying code
will follow soon.
Signed-off-by: NickLucche <nlucches@redhat.com>
parent 72ff142c
...@@ -72,4 +72,4 @@ For the PD disaggregation part, the Prefill instance receives cache exactly the ...@@ -72,4 +72,4 @@ For the PD disaggregation part, the Prefill instance receives cache exactly the
`docs/features/disagg_prefill.md` shows the brief idea about the disaggregated prefill (v0) `docs/features/disagg_prefill.md` shows the brief idea about the disaggregated prefill (v0)
We create the example setup with the **NixlConnector** from `vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py` and referred to the `tests/v1/kv_connector/nixl_integration/toy_proxy_server.py` to facilitate the kv transfer between P and D; We create the example setup with the **NixlConnector** from `vllm/distributed/kv_transfer/kv_connector/v1/nixl/` and referred to the `tests/v1/kv_connector/nixl_integration/toy_proxy_server.py` to facilitate the kv transfer between P and D;
...@@ -19,7 +19,7 @@ METRIC_SOURCE_FILES = [ ...@@ -19,7 +19,7 @@ METRIC_SOURCE_FILES = [
"output": "spec_decode.inc.md", "output": "spec_decode.inc.md",
}, },
{ {
"path": "vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py", "path": "vllm/distributed/kv_transfer/kv_connector/v1/nixl/stats.py",
"output": "nixl_connector.inc.md", "output": "nixl_connector.inc.md",
}, },
{"path": "vllm/v1/metrics/perf.py", "output": "perf.inc.md"}, {"path": "vllm/v1/metrics/perf.py", "output": "perf.inc.md"},
......
...@@ -21,7 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( ...@@ -21,7 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats, MultiKVConnectorStats,
MultiKVConnectorWorkerMetadata, MultiKVConnectorWorkerMetadata,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl import (
NixlKVConnectorStats, NixlKVConnectorStats,
) )
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
......
...@@ -24,13 +24,13 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( ...@@ -24,13 +24,13 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
TpKVTopology, TpKVTopology,
get_current_attn_backend, get_current_attn_backend,
) )
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector from vllm.distributed.kv_transfer.kv_connector.v1 import nixl
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats, MultiKVConnectorStats,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl import (
KVConnectorRole,
NixlAgentMetadata, NixlAgentMetadata,
NixlConnector, NixlConnector,
NixlConnectorMetadata, NixlConnectorMetadata,
...@@ -38,6 +38,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( ...@@ -38,6 +38,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorWorker, NixlConnectorWorker,
NixlHandshakePayload, NixlHandshakePayload,
NixlKVConnectorStats, NixlKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
compute_nixl_compatibility_hash, compute_nixl_compatibility_hash,
) )
from vllm.distributed.kv_transfer.kv_transfer_state import ( from vllm.distributed.kv_transfer.kv_transfer_state import (
...@@ -320,7 +322,7 @@ def test_prompt_less_than_block_size(): ...@@ -320,7 +322,7 @@ def test_prompt_less_than_block_size():
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_kv_transfer_handshake(dist_init): def test_kv_transfer_handshake(dist_init):
...@@ -534,7 +536,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -534,7 +536,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
class TestNixlHandshake: class TestNixlHandshake:
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_multi_xfer_one_engine( def test_multi_xfer_one_engine(
...@@ -621,7 +623,7 @@ class TestNixlHandshake: ...@@ -621,7 +623,7 @@ class TestNixlHandshake:
connector.clear_connector_metadata() connector.clear_connector_metadata()
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -691,7 +693,7 @@ class TestNixlHandshake: ...@@ -691,7 +693,7 @@ class TestNixlHandshake:
raise TimeoutError("Took too long to complete async handshake.") raise TimeoutError("Took too long to complete async handshake.")
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
@pytest.mark.parametrize("local_tp_size", [1, 2]) @pytest.mark.parametrize("local_tp_size", [1, 2])
...@@ -703,7 +705,7 @@ class TestNixlHandshake: ...@@ -703,7 +705,7 @@ class TestNixlHandshake:
remote configurations. remote configurations.
""" """
monkeypatch.setattr( monkeypatch.setattr(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.get_tensor_model_parallel_world_size",
lambda: local_tp_size, lambda: local_tp_size,
) )
...@@ -760,7 +762,7 @@ class TestNixlHandshake: ...@@ -760,7 +762,7 @@ class TestNixlHandshake:
check_handshake(6) check_handshake(6)
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_prefill_tp_size_greater_than_decode_tp_size_mla( def test_prefill_tp_size_greater_than_decode_tp_size_mla(
...@@ -863,7 +865,7 @@ class TestNixlHandshake: ...@@ -863,7 +865,7 @@ class TestNixlHandshake:
assert req_id not in conn_p1.connector_worker._reqs_to_process assert req_id not in conn_p1.connector_worker._reqs_to_process
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_concurrent_load_kv( def test_concurrent_load_kv(
...@@ -928,7 +930,7 @@ class TestNixlHandshake: ...@@ -928,7 +930,7 @@ class TestNixlHandshake:
raise TimeoutError("Took too long to complete async handshake.") raise TimeoutError("Took too long to complete async handshake.")
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_handshake_fails_on_kv_cache_layout_mismatch( def test_handshake_fails_on_kv_cache_layout_mismatch(
...@@ -943,7 +945,7 @@ class TestNixlHandshake: ...@@ -943,7 +945,7 @@ class TestNixlHandshake:
# Mock TP world size to 2 to force heterogeneous TP when # Mock TP world size to 2 to force heterogeneous TP when
# remote_tp_size=1 # remote_tp_size=1
with patch( with patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501 "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.get_tensor_model_parallel_world_size", # noqa: E501
return_value=2, return_value=2,
): ):
# Initialize connector and worker (with fake NIXL wrapper) # Initialize connector and worker (with fake NIXL wrapper)
...@@ -982,7 +984,7 @@ class TestNixlHandshake: ...@@ -982,7 +984,7 @@ class TestNixlHandshake:
worker.add_remote_agent(meta, remote_tp_size=1) worker.add_remote_agent(meta, remote_tp_size=1)
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
...@@ -997,7 +999,7 @@ class TestNixlHandshake: ...@@ -997,7 +999,7 @@ class TestNixlHandshake:
# Mock TP world size to 2 to force heterogeneous TP when # Mock TP world size to 2 to force heterogeneous TP when
# remote_tp_size=1 # remote_tp_size=1
with patch( with patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501 "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.get_tensor_model_parallel_world_size", # noqa: E501
return_value=2, return_value=2,
): ):
# Initialize connector and worker (with fake NIXL wrapper) # Initialize connector and worker (with fake NIXL wrapper)
...@@ -1042,7 +1044,7 @@ class TestNixlHandshake: ...@@ -1042,7 +1044,7 @@ class TestNixlHandshake:
# we put here is important. First run ray, it will clean up the resources, then # we put here is important. First run ray, it will clean up the resources, then
# the rest of the tests. # the rest of the tests.
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_kv_connector_stats(default_vllm_config, dist_init): def test_kv_connector_stats(default_vllm_config, dist_init):
...@@ -1227,8 +1229,8 @@ def test_multi_kv_connector_stats_aggregation(): ...@@ -1227,8 +1229,8 @@ def test_multi_kv_connector_stats_aggregation():
worker_patterns = [(2, 1), (3, 0), (0, 5)] # (Nixl, Foo) worker_patterns = [(2, 1), (3, 0), (0, 5)] # (Nixl, Foo)
worker_outputs: list[ModelRunnerOutput] = [] worker_outputs: list[ModelRunnerOutput] = []
for i, (nixl, foo) in enumerate(worker_patterns): for i, (nixl_count, foo) in enumerate(worker_patterns):
stats = make_multi_stats(nixl, foo) stats = make_multi_stats(nixl_count, foo)
output = ModelRunnerOutput( output = ModelRunnerOutput(
req_ids=[f"req_{i}"], req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0}, req_id_to_index={f"req_{i}": 0},
...@@ -1256,7 +1258,7 @@ def test_multi_kv_connector_stats_aggregation(): ...@@ -1256,7 +1258,7 @@ def test_multi_kv_connector_stats_aggregation():
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_scheduler_kv_connector_stats_aggregation(): def test_scheduler_kv_connector_stats_aggregation():
...@@ -1324,7 +1326,7 @@ def test_scheduler_kv_connector_stats_aggregation(): ...@@ -1324,7 +1326,7 @@ def test_scheduler_kv_connector_stats_aggregation():
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None]) @pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
...@@ -1513,13 +1515,14 @@ def test_register_kv_caches( ...@@ -1513,13 +1515,14 @@ def test_register_kv_caches(
backend_cls = TritonAttentionBackend backend_cls = TritonAttentionBackend
nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" nixl_worker = "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker"
nixl_connector = "vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector"
with ( with (
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper, patch(f"{nixl_worker}.NixlWrapper") as mock_nixl_wrapper,
patch(f"{nixl_module}.threading.Event"), patch(f"{nixl_worker}.threading.Event"),
patch(f"{nixl_module}.threading.Thread") as mock_thread, patch(f"{nixl_worker}.threading.Thread") as mock_thread,
patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend, patch(f"{nixl_connector}.get_current_attn_backend") as mock_get_attn_backend,
patch(f"{nixl_module}.get_current_attn_backends") as mock_get_attn_backends, patch(f"{nixl_worker}.get_current_attn_backends") as mock_get_attn_backends,
): ):
# Ensure get_attn_backend returns the correct value due to # Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous # _cached_get_attn_backend returning the backend from previous
...@@ -1754,28 +1757,26 @@ def test_kv_buffer_to_nixl_memory_types( ...@@ -1754,28 +1757,26 @@ def test_kv_buffer_to_nixl_memory_types(
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
# Override the default memory types in the config # Override the default memory types in the config
vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import (
_NIXL_SUPPORTED_DEVICE, _NIXL_SUPPORTED_DEVICE,
) )
_NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices()) _NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices())
with ( with (
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper"),
patch( patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper" "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.threading.Event"
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"
), ),
patch( patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread" "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.threading.Thread"
), ),
patch( patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.current_platform",
FakePlatform, FakePlatform,
), ),
patch( patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils._NIXL_SUPPORTED_DEVICE",
_NIXL_SUPPORTED_DEVICE, _NIXL_SUPPORTED_DEVICE,
), ),
): # noqa: E501 ): # noqa: E501
...@@ -1790,7 +1791,7 @@ def test_kv_buffer_to_nixl_memory_types( ...@@ -1790,7 +1791,7 @@ def test_kv_buffer_to_nixl_memory_types(
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_shutdown_cleans_up_resources(default_vllm_config, dist_init): def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
...@@ -1855,7 +1856,7 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init): ...@@ -1855,7 +1856,7 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_init): def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_init):
...@@ -1975,7 +1976,7 @@ class FailingNixlWrapper(FakeNixlWrapper): ...@@ -1975,7 +1976,7 @@ class FailingNixlWrapper(FakeNixlWrapper):
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FailingNixlWrapper, FailingNixlWrapper,
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -2065,10 +2066,10 @@ def test_transfer_failure_logging( ...@@ -2065,10 +2066,10 @@ def test_transfer_failure_logging(
slot_mapping={}, slot_mapping={},
) )
# Capture logs from the nixl_connector logger specifically # Capture logs from the nixl.worker logger specifically
# vLLM loggers have propagate=False, so we need to capture directly # vLLM loggers have propagate=False, so we need to capture directly
nixl_logger = logging.getLogger( nixl_logger = logging.getLogger(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker"
) )
captured_logs: list[logging.LogRecord] = [] captured_logs: list[logging.LogRecord] = []
...@@ -2130,7 +2131,7 @@ def test_transfer_failure_logging( ...@@ -2130,7 +2131,7 @@ def test_transfer_failure_logging(
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FailingNixlWrapper, FailingNixlWrapper,
) )
def test_handshake_failure_returns_finished(default_vllm_config, dist_init): def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
...@@ -2181,7 +2182,7 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init): ...@@ -2181,7 +2182,7 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FailingNixlWrapper, FailingNixlWrapper,
) )
def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init): def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init):
...@@ -2257,7 +2258,7 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init) ...@@ -2257,7 +2258,7 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
], ],
) )
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_compatibility_hash_validation( def test_compatibility_hash_validation(
...@@ -2328,7 +2329,7 @@ def test_compatibility_hash_validation( ...@@ -2328,7 +2329,7 @@ def test_compatibility_hash_validation(
elif "connector_version" in version_override: elif "connector_version" in version_override:
stack.enter_context( stack.enter_context(
patch.object( patch.object(
nixl_connector, nixl.metadata,
"NIXL_CONNECTOR_VERSION", "NIXL_CONNECTOR_VERSION",
version_override["connector_version"], version_override["connector_version"],
) )
...@@ -2365,7 +2366,7 @@ def test_compatibility_hash_validation( ...@@ -2365,7 +2366,7 @@ def test_compatibility_hash_validation(
# Patch zmq_ctx to return our mock socket # Patch zmq_ctx to return our mock socket
with ( with (
patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"), patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx, patch.object(nixl.worker, "zmq_ctx") as mock_zmq_ctx,
): ):
mock_zmq_ctx.return_value.__enter__.return_value = mock_socket mock_zmq_ctx.return_value.__enter__.return_value = mock_socket
...@@ -2399,7 +2400,7 @@ def test_compatibility_hash_validation( ...@@ -2399,7 +2400,7 @@ def test_compatibility_hash_validation(
], ],
) )
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario): def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario):
...@@ -2464,7 +2465,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) ...@@ -2464,7 +2465,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
mock_socket.recv.return_value = msg_bytes mock_socket.recv.return_value = msg_bytes
with ( with (
patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"), patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx, patch.object(nixl.worker, "zmq_ctx") as mock_zmq_ctx,
): ):
mock_zmq_ctx.return_value.__enter__.return_value = mock_socket mock_zmq_ctx.return_value.__enter__.return_value = mock_socket
......
...@@ -31,10 +31,10 @@ from .utils import ( ...@@ -31,10 +31,10 @@ from .utils import (
(False, [0]), (False, [0]),
], ],
) )
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") @patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler.current_platform")
def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes): def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes):
"""Test sw_sizes is correctly computed based on SWA enabled/disabled.""" """Test sw_sizes is correctly computed based on SWA enabled/disabled."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import (
NixlConnectorScheduler, NixlConnectorScheduler,
) )
...@@ -65,7 +65,7 @@ def test_logical_to_kernel_block_ids_with_hma(): ...@@ -65,7 +65,7 @@ def test_logical_to_kernel_block_ids_with_hma():
When HMA is enabled, the logical block size may differ from the kernel When HMA is enabled, the logical block size may differ from the kernel
block size. Each logical block maps to multiple kernel blocks. block size. Each logical block maps to multiple kernel blocks.
""" """
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker, NixlConnectorWorker,
) )
...@@ -169,7 +169,7 @@ def test_nixl_metadata_hma_block_ids_structure(): ...@@ -169,7 +169,7 @@ def test_nixl_metadata_hma_block_ids_structure():
Test that NixlConnectorMetadata correctly stores block IDs for multiple Test that NixlConnectorMetadata correctly stores block IDs for multiple
KV cache groups when HMA is enabled. KV cache groups when HMA is enabled.
""" """
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
NixlConnectorMetadata, NixlConnectorMetadata,
) )
...@@ -211,7 +211,7 @@ def test_nixl_metadata_hma_block_ids_structure(): ...@@ -211,7 +211,7 @@ def test_nixl_metadata_hma_block_ids_structure():
def test_get_block_descs_ids_hybrid_ssm(): def test_get_block_descs_ids_hybrid_ssm():
"""Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM """Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM
when ratio=1 (no kernel block size mismatch).""" when ratio=1 (no kernel block size mismatch)."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker, NixlConnectorWorker,
) )
...@@ -247,7 +247,7 @@ def test_get_block_descs_ids_hybrid_ssm(): ...@@ -247,7 +247,7 @@ def test_get_block_descs_ids_hybrid_ssm():
def test_get_block_descs_ids_kernel_block_mismatch(): def test_get_block_descs_ids_kernel_block_mismatch():
"""Test _get_block_descs_ids uses different strides for FA (kernel blocks) """Test _get_block_descs_ids uses different strides for FA (kernel blocks)
vs SSM (logical blocks) when ratio > 1.""" vs SSM (logical blocks) when ratio > 1."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker, NixlConnectorWorker,
) )
...@@ -284,7 +284,7 @@ def test_get_block_descs_ids_kernel_block_mismatch(): ...@@ -284,7 +284,7 @@ def test_get_block_descs_ids_kernel_block_mismatch():
def test_nixl_metadata_hybrid_ssm_block_ids(): def test_nixl_metadata_hybrid_ssm_block_ids():
"""Test NixlConnectorMetadata correctly stores block IDs for FA + SSM """Test NixlConnectorMetadata correctly stores block IDs for FA + SSM
groups with different block counts (kernel mismatch active).""" groups with different block counts (kernel mismatch active)."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
NixlConnectorMetadata, NixlConnectorMetadata,
) )
...@@ -392,7 +392,7 @@ def test_mamba_n1_p_side_truncation(): ...@@ -392,7 +392,7 @@ def test_mamba_n1_p_side_truncation():
], ],
ids=["fa_swa_mamba", "fa_swa_only", "fa_only"], ids=["fa_swa_mamba", "fa_swa_only", "fa_only"],
) )
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") @patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler.current_platform")
def test_has_mamba_init( def test_has_mamba_init(
mock_platform, mock_platform,
swa_enabled, swa_enabled,
...@@ -401,7 +401,7 @@ def test_has_mamba_init( ...@@ -401,7 +401,7 @@ def test_has_mamba_init(
expected_is_hma, expected_is_hma,
): ):
"""Test _has_mamba / _is_hma_required derived from kv_cache_groups.""" """Test _has_mamba / _is_hma_required derived from kv_cache_groups."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import (
NixlConnectorScheduler, NixlConnectorScheduler,
) )
......
...@@ -587,7 +587,7 @@ def test_cannot_recv(): ...@@ -587,7 +587,7 @@ def test_cannot_recv():
assert_scheduler_empty(scheduler) assert_scheduler_empty(scheduler)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") @patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler.current_platform")
def test_p_side_chunked_prefill_mamba(mock_platform): def test_p_side_chunked_prefill_mamba(mock_platform):
"""P-side integration: Mamba N-1 truncation + chunked prefill completes. """P-side integration: Mamba N-1 truncation + chunked prefill completes.
......
...@@ -476,7 +476,7 @@ def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False): ...@@ -476,7 +476,7 @@ def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False):
Only sets the two flags needed by the N-1 prefill logic. Only sets the two flags needed by the N-1 prefill logic.
""" """
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import (
NixlConnectorScheduler, NixlConnectorScheduler,
) )
......
...@@ -178,7 +178,7 @@ KVConnectorFactory.register_connector( ...@@ -178,7 +178,7 @@ KVConnectorFactory.register_connector(
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
"NixlConnector", "NixlConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", "vllm.distributed.kv_transfer.kv_connector.v1.nixl",
"NixlConnector", "NixlConnector",
) )
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""NIXL KV-cache transfer connector (disaggregated prefill / decode)."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector import (
NixlConnector,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
NixlAgentMetadata,
NixlConnectorMetadata,
NixlHandshakePayload,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import (
NixlConnectorScheduler,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.stats import (
NixlKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)
# Public API of the `nixl` package: exactly the names re-exported by the
# imports above, so `from ...v1.nixl import X` works for each of them.
__all__ = [
"NixlAgentMetadata",
"NixlConnector",
"NixlConnectorMetadata",
"NixlConnectorScheduler",
"NixlConnectorWorker",
"NixlHandshakePayload",
"NixlKVConnectorStats",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""NixlConnector – thin facade that delegates to scheduler / worker."""
from typing import TYPE_CHECKING, Any
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
get_current_attn_backend,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
KVConnectorBase_V1,
KVConnectorHandshakeMetadata,
KVConnectorMetadata,
KVConnectorRole,
SupportsHMA,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
NixlConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import (
NixlConnectorScheduler,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.stats import (
NixlKVConnectorStats,
NixlPromMetrics,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import MambaSpec
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
# Module-level logger, created from this module's `__name__`.
logger = init_logger(__name__)
class NixlConnector(KVConnectorBase_V1, SupportsHMA):
@property
def prefer_cross_layer_blocks(self) -> bool:
    """Whether the cross-layer KV block layout should be used.

    Returns ``True`` only when every one of these conditions holds,
    checked in this order:

    * no KV cache group has a ``MambaSpec`` as its ``kv_cache_spec``;
    * the current attention backend is named ``FLASH_ATTN``,
      ``FLASHINFER`` or ``TRITON_ATTN``;
    * the KV cache layout is ``"HND"``;
    * the ``enable_cross_layers_blocks`` entry of the connector extra
      config, converted to a lower-cased string, equals ``"true"``
      (a missing entry is treated as ``"False"``).
    """
    # Hybrid SSM models do not yet support cross-layer layout.
    # A generator (rather than a list) lets any() stop at the first
    # Mamba group instead of materializing a throwaway list.
    if any(
        isinstance(group.kv_cache_spec, MambaSpec)
        for group in self.kv_cache_config.kv_cache_groups
    ):
        return False

    backend = get_current_attn_backend(self._vllm_config)
    if backend.get_name() not in (
        "FLASH_ATTN",
        "FLASHINFER",
        "TRITON_ATTN",
    ):
        return False

    # For now there is no benefit to run cross layers when backend
    # does not support on HND
    if get_kv_cache_layout() != "HND":
        return False

    extra_config = self.kv_transfer_config.kv_connector_extra_config
    return (
        str(extra_config.get("enable_cross_layers_blocks", "False")).lower()
        == "true"
    )
def __init__(
    self,
    vllm_config: VllmConfig,
    role: KVConnectorRole,
    kv_cache_config: "KVCacheConfig",
):
    """Create the connector and the delegate that matches ``role``.

    Args:
        vllm_config: Engine configuration; its ``kv_transfer_config`` and
            that config's ``engine_id`` must both be set.
        role: ``KVConnectorRole.SCHEDULER`` builds only the scheduler-side
            delegate, ``KVConnectorRole.WORKER`` only the worker-side one.
        kv_cache_config: KV cache configuration, stored on the instance
            and handed to the delegate.
    """
    super().__init__(vllm_config, role, kv_cache_config)

    transfer_config = vllm_config.kv_transfer_config
    assert transfer_config is not None
    assert transfer_config.engine_id is not None

    self.kv_cache_config = kv_cache_config
    self.engine_id: EngineId = transfer_config.engine_id
    self.kv_transfer_config = transfer_config

    # Only the delegate for the current role is instantiated; the other
    # attribute is explicitly set to None.
    if role == KVConnectorRole.SCHEDULER:
        scheduler = NixlConnectorScheduler(
            vllm_config, self.engine_id, kv_cache_config
        )
        self.connector_scheduler: NixlConnectorScheduler | None = scheduler
        self.connector_worker: NixlConnectorWorker | None = None
    elif role == KVConnectorRole.WORKER:
        self.connector_scheduler = None
        self.connector_worker = NixlConnectorWorker(
            vllm_config, self.engine_id, kv_cache_config
        )
############################################################
# Class Methods
############################################################
@classmethod
def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
if vllm_config.model_config is None:
logger.warning_once(
"Unable to detect current VLLM config. "
"Fallback to default kv cache layout."
)
return None
use_mla = vllm_config.model_config.use_mla
if use_mla:
# return None when we have mla
# as the layout should not matter in that case,
# which fallback to the default behavior.
return None
logger.info_once(
"NixlConnector setting KV cache layout to HND for better xfer performance."
)
return "HND"
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens
)
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens
)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, (block_ids,))
def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (dict): the handshake metadata to set.
"""
assert self.connector_scheduler is not None
self.connector_scheduler.set_xfer_handshake_metadata(metadata)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_caches(kv_cache)
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
assert self.connector_worker is not None
self.connector_worker.set_host_xfer_buffer_ops(copy_operation)
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
return self.connector_worker.get_finished()
def get_block_ids_with_load_errors(self) -> set[int]:
"""Get block IDs that failed to load via NIXL."""
assert self.connector_worker is not None
return self.connector_worker.get_block_ids_with_load_errors()
def get_kv_connector_stats(self) -> KVConnectorStats | None:
if self.connector_worker is None:
return None
return self.connector_worker.get_kv_connector_stats()
@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> KVConnectorStats | None:
return (
NixlKVConnectorStats(data=data)
if data is not None
else NixlKVConnectorStats()
)
@classmethod
def build_prom_metrics(
cls,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
) -> KVConnectorPromMetrics:
return NixlPromMetrics(
vllm_config, metric_types, labelnames, per_engine_labelvalues
)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
"""NixlConnector does not do layerwise saving."""
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""NixlConnector does not save explicitly."""
pass
def wait_for_save(self):
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
if self.connector_worker.use_host_buffer and self.connector_worker.copy_blocks:
self.connector_worker.save_kv_to_host(self._connector_metadata)
def shutdown(self):
if self.connector_worker is not None:
self.connector_worker.shutdown()
if self.connector_scheduler is not None:
self.connector_scheduler.shutdown()
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
Get the KVConnector handshake metadata for this connector.
This metadata is used for out-of-band connector handshake
between P/D workers.
Returns:
KVConnectorHandshakeMetadata: the handshake metadata.
None if no handshake metadata is available.
"""
assert self.connector_worker is not None
return self.connector_worker.xfer_handshake_metadata
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Metadata dataclasses and helpers for the NIXL connector."""
from dataclasses import dataclass
from typing import Any
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import BlockIds
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata,
KVConnectorMetadata,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
TransferHandle = int
ReqId = str
GET_META_MSG = b"get_meta_msg"
#
# NIXL Connector Version
#
# Increment this version whenever there is an incompatible change to:
# - NixlAgentMetadata schema
# - kv_transfer_params schema or semantics
# - NIXL transfer protocol or wire format
# - KV cache memory layout or block organization
# - Any other change that breaks P/D interoperability
#
# Version History:
# 1: Initial version with compatibility checking
# 2: Add remote_request_id to kv_transfer_params
#
NIXL_CONNECTOR_VERSION: int = 2
@dataclass
class NixlAgentMetadata:
    """Description of one NIXL agent, carried (encoded) inside
    ``NixlHandshakePayload.agent_metadata_bytes``.

    Field order is part of the positional constructor and must not change.
    """

    # Identifier of the engine this agent belongs to.
    engine_id: str
    # Opaque agent metadata blob.
    agent_metadata: bytes
    # Base addresses of the registered KV caches.
    # NOTE(review): presumably one entry per registered region - confirm.
    kv_caches_base_addr: list[int]
    device_id: int
    num_blocks: int
    # NOTE(review): presumably byte length of a block per region - confirm.
    block_lens: list[int]
    # KV cache layout name (this file uses "HND" as one such value).
    kv_cache_layout: str
    block_size: int
    # NOTE(review): meaning of the two SSM sizes is not visible here.
    ssm_sizes: tuple[int, int]
    attn_backend_name: str
@dataclass
class NixlHandshakePayload(KVConnectorHandshakeMetadata):
    """
    Envelope for the NIXL handshake as it travels over the wire.

    The envelope allows decoding in two phases so that incompatible peers
    fail gracefully:
    1. Decode the NixlHandshakePayload itself to read compatibility_hash.
    2. Compute the local hash and compare the two.
    3. Decode agent_metadata_bytes only when the hashes match.

    This avoids decoder errors when the NixlAgentMetadata schema differs
    between peers, so the failure can be reported with a clear message.
    """

    # Hash the receiver compares against its own before decoding the body.
    compatibility_hash: str
    # NixlAgentMetadata encoded
    agent_metadata_bytes: bytes
def compute_nixl_compatibility_hash(
    vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool
) -> str:
    """
    Build the compatibility hash used to gate NIXL KV transfers.

    Only factors that decide whether two NIXL instances can exchange KV
    cache data are hashed:
    - vLLM version and NIXL connector version
    - Model identity and shape (model, dtype, total KV heads, head size,
      total hidden layers)
    - Attention backend name and KV cache dtype
    - Whether cross-layer blocks are used
    - Whether the hybrid KV cache manager is enabled

    Note: tensor_parallel_size, block_size and kv_cache_layout are, per the
    original documentation, validated at runtime in
    _validate_remote_agent_handshake and deliberately left out of the hash
    to allow heterogeneous deployments.

    Note: the set of factors is expected to change over time, becoming more
    or less permissive.

    Args:
        vllm_config: source of the model, cache and scheduler settings.
        attn_backend_name: name of the attention backend in use.
        cross_layers_blocks: whether the cross-layer block layout is used.

    Returns:
        The digest produced by ``hash_factors`` (documented upstream as a
        SHA-256 hex digest).
    """
    from vllm import __version__ as vllm_version
    from vllm.config.utils import hash_factors

    model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config
    is_hma_enabled = not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager

    # Version compatibility
    version_factors = {
        "vllm_version": vllm_version,
        "nixl_connector_version": NIXL_CONNECTOR_VERSION,
    }
    # Model architecture - affects KV cache shape
    model_factors = {
        "model": model_config.model,
        "dtype": str(model_config.dtype),
        "num_kv_heads": model_config.get_total_num_kv_heads(),
        "head_size": model_config.get_head_size(),
        "num_hidden_layers": model_config.get_total_num_hidden_layers(),
    }
    # Attention backend and KV cache dtype affect memory layout
    layout_factors = {
        "attn_backend_name": attn_backend_name,
        "cache_dtype": str(cache_config.cache_dtype),
        "cross_layers_blocks": cross_layers_blocks,
        "is_hma_enabled": is_hma_enabled,
    }
    # Merge in the same key order as the individual groups above.
    factors = {**version_factors, **model_factors, **layout_factors}

    compat_hash = hash_factors(factors)
    logger.debug(
        "NIXL compatibility hash: %s (model=%s, dtype=%s, num_kv_heads=%d, "
        "cache_dtype=%s, attn_backend=%s)",
        compat_hash,
        factors["model"],
        factors["dtype"],
        factors["num_kv_heads"],
        factors["cache_dtype"],
        attn_backend_name,
    )
    return compat_hash
@dataclass
class RemoteMeta:
    """Where and what to read on the remote side for one request.

    Populated from the ``remote_*`` entries of ``kv_transfer_params``.
    """

    # Remote block ids to read.
    block_ids: BlockIds
    # Remote host and port taken from kv_transfer_params.
    host: str
    port: int
    # Engine id of the remote instance.
    engine_id: str
    # Request id as known on the remote instance.
    request_id: str
@dataclass
class ReqMeta:
    """Per-request transfer metadata passed from scheduler to worker."""

    # Local block ids for the request.
    local_block_ids: BlockIds
    # To be used when logical block size does not match the kernel block size
    local_physical_block_ids: BlockIds
    # Tensor-parallel size taken from kv_transfer_params (defaults to 1).
    tp_size: int
    # Remote endpoint details; only set for requests that receive KV.
    remote: RemoteMeta | None = None
class NixlConnectorMetadata(KVConnectorMetadata):
    """Per-step connector metadata holding the requests to receive, save and
    send, plus bookkeeping sets of request ids."""

    def __init__(self):
        # Requests whose KV must be received / saved, keyed by request id.
        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
        self.reqs_to_save: dict[ReqId, ReqMeta] = {}
        # Requests to send, mapped to a float (an expiration time, per the
        # scheduler that fills this in).
        self.reqs_to_send: dict[ReqId, float] = {}
        self.reqs_in_batch: set[ReqId] = set()
        self.reqs_not_processed: set[ReqId] = set()

    def _add_new_req(
        self,
        local_block_ids: BlockIds,
        kv_transfer_params: dict[str, Any],
    ) -> ReqMeta:
        """Create a ``ReqMeta`` without remote info.

        The physical block ids start out equal to the logical ones.
        """
        # P workers don't need to receive tp_size from proxy here.
        tp_size = kv_transfer_params.get("tp_size", 1)
        return ReqMeta(
            local_block_ids=local_block_ids,
            local_physical_block_ids=local_block_ids,
            tp_size=tp_size,
        )

    def add_new_req_to_save(
        self,
        request_id: ReqId,
        local_block_ids: BlockIds,
        kv_transfer_params: dict[str, Any],
    ):
        """Register ``request_id`` as a request whose KV must be saved."""
        req_meta = self._add_new_req(local_block_ids, kv_transfer_params)
        self.reqs_to_save[request_id] = req_meta

    def add_new_req_to_recv(
        self,
        request_id: ReqId,
        local_block_ids: BlockIds,
        kv_transfer_params: dict[str, Any],
    ):
        """Register ``request_id`` as a request whose KV must be received.

        Raises:
            KeyError: if any ``remote_*`` entry is missing from
                ``kv_transfer_params``.
        """
        req_meta = self._add_new_req(local_block_ids, kv_transfer_params)
        remote = RemoteMeta(
            block_ids=kv_transfer_params["remote_block_ids"],
            engine_id=kv_transfer_params["remote_engine_id"],
            request_id=kv_transfer_params["remote_request_id"],
            host=kv_transfer_params["remote_host"],
            port=kv_transfer_params["remote_port"],
        )
        req_meta.remote = remote
        self.reqs_to_recv[request_id] = req_meta
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Scheduler-side logic for the NIXL connector."""
import threading
import time
from typing import TYPE_CHECKING, Any
import msgspec
import zmq
from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.utils import (
BlockIds,
EngineId,
yield_req_data,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata,
KVConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
GET_META_MSG,
NixlConnectorMetadata,
NixlHandshakePayload,
ReqId,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import zmq_ctx
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.network_utils import make_zmq_path
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
MambaSpec,
SlidingWindowSpec,
)
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
class NixlConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(
        self,
        vllm_config: "VllmConfig",
        engine_id: str,
        kv_cache_config: "KVCacheConfig",
    ):
        """Capture configuration and initialise per-step request tracking.

        Args:
            vllm_config: global config; ``kv_transfer_config`` must be set.
            engine_id: id of this engine, advertised to remote peers.
            kv_cache_config: provides the KV cache groups inspected below.
        """
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.engine_id: EngineId = engine_id
        self.kv_cache_config = kv_cache_config
        self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
        # Offset the base port by the data-parallel index of this instance.
        self.side_channel_port = (
            envs.VLLM_NIXL_SIDE_CHANNEL_PORT
            + vllm_config.parallel_config.data_parallel_index
        )
        assert vllm_config.kv_transfer_config is not None
        if current_platform.device_type == "cpu":
            self.use_host_buffer = False
        else:
            self.use_host_buffer = (
                vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
            )
        self._is_hma_required = (
            not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
            # Also handle unlikely SW-only model case instead of checking num_groups>1.
            and any(
                not isinstance(g.kv_cache_spec, FullAttentionSpec)
                for g in kv_cache_config.kv_cache_groups
            )
        )
        self._has_mamba = any(
            isinstance(g.kv_cache_spec, MambaSpec)
            for g in kv_cache_config.kv_cache_groups
        )

        logger.info("Initializing NIXL Scheduler %s", engine_id)
        # HMA is enabled exactly when the hybrid KV cache manager is NOT
        # disabled; only then is the "enabled" message accurate.
        if not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
            logger.info("Hybrid Memory Allocator is enabled with NIXL")

        # Background thread for handling new handshake requests.
        self._nixl_handshake_listener_t: threading.Thread | None = None
        self._stop_event = threading.Event()

        # Requests that need to start recv/send.
        # New requests are added by update_state_after_alloc in
        # the scheduler. Used to make metadata passed to Worker.
        self._reqs_need_recv: dict[ReqId, tuple[Request, BlockIds]] = {}
        self._reqs_need_save: dict[ReqId, Request] = {}
        # Reqs to send and their expiration time
        self._reqs_need_send: dict[ReqId, float] = {}
        self._reqs_in_batch: set[ReqId] = set()
        # Reqs to remove from processed set because they're not to send after
        # remote prefill or aborted.
        self._reqs_not_processed: set[ReqId] = set()

        # Gather Sliding Window sizes for each kv cache group (if any) in number of
        # blocks per KV cache group. This is used to clip the local attention window.
        sw_sizes_tokens: list[tuple[int, int]] = [
            (g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size)
            if isinstance(g.kv_cache_spec, SlidingWindowSpec)
            else (0, self.block_size)
            for g in kv_cache_config.kv_cache_groups
        ]
        # cdiv(n_tokens, block_size) gives blocks/window; add 1 to conservatively
        # account for boundary overlap eg window isn't fully aligned with blocks.
        self.blocks_per_sw = [
            cdiv(n_tokens, block_size) + 1 if n_tokens else 0
            for n_tokens, block_size in sw_sizes_tokens
        ]

    def shutdown(self):
        """Signal the handshake listener to stop and wait for it to exit."""
        self._stop_event.set()
        if self._nixl_handshake_listener_t is not None:
            self._nixl_handshake_listener_t.join()
            self._nixl_handshake_listener_t = None

    def get_sw_clipped_blocks(self, block_ids: BlockIds) -> BlockIds:
        """
        Clip the number of blocks to the sliding window size for each kv cache group
        that employs SWA.
        This is necessary because the KV Cache manager initially allocates blocks for
        the entire sequence length, and successively cleans up blocks that are outside
        the window prior to the `request_finished_all_groups` hook.

        Returns ``block_ids`` unchanged when it is empty or HMA is not
        required; otherwise a tuple with one (possibly clipped) list per
        group.
        """
        if len(block_ids) == 0 or not self._is_hma_required:
            # No blocks to clip eg Full prefix cache hit or not a hybrid model.
            return block_ids
        # NOTE (NickLucche) This logic is currently handled at the connector level
        # because offloading connectors might want to receive the whole sequence even
        # for SWA groups. We will abstract this logic once the interface is more stable
        assert len(block_ids) == len(self.blocks_per_sw), (
            "Number of KV cache groups must match"
        )
        # For non-SWA groups, blocks_per_sw is 0 so we return all block_ids unchanged
        return tuple(
            [
                blocks[-self.blocks_per_sw[i] :]
                if self.blocks_per_sw[i] > 0
                else blocks
                for i, blocks in enumerate(block_ids)
            ]
        )

    def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]
    ) -> None:
        """
        Set the KV connector handshake metadata for this connector.

        Each entry is msgpack-encoded once; the listener thread is started on
        the first call only and later calls do not restart it.

        Args:
            metadata (dict): the handshake metadata to set.

        Raises:
            ValueError: if an entry is not a ``NixlHandshakePayload``.
        """
        encoded_data: dict[int, bytes] = {}
        encoder = msgspec.msgpack.Encoder()
        for tp_rank, rank_metadata in metadata.items():
            if not isinstance(rank_metadata, NixlHandshakePayload):
                raise ValueError(
                    "NixlConnectorScheduler expects NixlHandshakePayload for "
                    "handshake metadata."
                )
            encoded_data[tp_rank] = encoder.encode(rank_metadata)
            logger.debug(
                "Tp rank %d: encoded NixlHandshakePayload size: %s bytes",
                tp_rank,
                str(len(encoded_data[tp_rank])),
            )

        # Only start the listener when we have metadata to serve.
        if self._nixl_handshake_listener_t is None:
            ready_event = threading.Event()
            self._nixl_handshake_listener_t = threading.Thread(
                target=self._nixl_handshake_listener,
                args=(
                    encoded_data,
                    ready_event,
                    self._stop_event,
                    self.side_channel_port,
                ),
                daemon=True,
                name="nixl_handshake_listener",
            )
            self._nixl_handshake_listener_t.start()
            ready_event.wait()  # Wait for listener ZMQ socket to be ready.

    @staticmethod
    def _nixl_handshake_listener(
        encoded_data: dict[int, Any],
        ready_event: threading.Event,
        stop_event: threading.Event,
        port: int,
    ):
        """Background thread for getting new NIXL handshakes.

        Serves the pre-encoded payload for the requested TP rank on a ROUTER
        socket until ``stop_event`` is set.
        """
        # NOTE(rob): this is a simple implementation. We will move
        # to a better approach via HTTP endpoint soon.

        # Listen for new requests for metadata.
        host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
        path = make_zmq_path("tcp", host, port)
        logger.debug("Starting listening on path: %s", path)
        with zmq_ctx(zmq.ROUTER, path) as sock:
            # Time out receives every second so the stop event gets polled.
            sock.setsockopt(zmq.RCVTIMEO, 1000)
            ready_event.set()
            while True:
                try:
                    identity, _, msg = sock.recv_multipart()
                except zmq.Again:
                    if stop_event.is_set():
                        break
                    continue
                # Decode the message which contains (GET_META_MSG, rank)
                msg, target_tp_rank = msgspec.msgpack.decode(msg)
                logger.debug(
                    "Received message for tp rank %s",
                    target_tp_rank,
                )
                # An unexpected message is only logged; the reply is still sent.
                if msg != GET_META_MSG:
                    logger.warning("Connection listener got unexpected message %s", msg)
                sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))

    def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int:
        """D-side only. Returns N-1 for Mamba models since the decoder
        always recomputes the last token and must start from h(N-1)."""
        if self._has_mamba and num_prompt_tokens > 1:
            return num_prompt_tokens - 1
        return num_prompt_tokens

    def _truncate_mamba_request_for_prefill(self, request: "Request") -> None:
        """P-side only: drop the last prompt token so the prefiller computes
        h(N-1) instead of h(N). The decoder recomputes the last token to
        derive h(N) correctly.

        Guarded by ``_p_side_truncated`` to avoid repeated truncation if the
        request is preempted and rescheduled."""
        params = request.kv_transfer_params
        if (
            params is not None
            # Guard against repeated truncation after preemption/reschedule.
            and not params.get("_p_side_truncated")
            and request.num_prompt_tokens > 1
        ):
            if request.prompt_token_ids is not None:
                request.prompt_token_ids.pop()
            elif request.prompt_embeds is not None:
                request.prompt_embeds = request.prompt_embeds[:-1]
            else:
                # Neither token ids nor embeddings: nothing to truncate.
                return

            request._all_token_ids.pop()
            request.num_prompt_tokens -= 1
            request.max_tokens = 1
            params["_p_side_truncated"] = True

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int, bool]:
        """
        For remote prefill, pull all prompt blocks from remote
        asynchronously relative to engine execution.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            * the number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            * true if the external KV cache tokens will be loaded
              asynchronously (between scheduler steps).
        """
        params = request.kv_transfer_params
        logger.debug(
            "NIXLConnector get_num_new_matched_tokens: "
            "num_computed_tokens=%s, kv_transfer_params=%s",
            num_computed_tokens,
            params,
        )

        if params is not None and params.get("do_remote_prefill"):
            # Remote prefill: get all prompt blocks from remote.
            token_ids = request.prompt_token_ids or []
            actual = self._mamba_prefill_token_count(len(token_ids))
            count = actual - num_computed_tokens
            if count > 0:
                return count, True

        if params is not None and params.get("do_remote_decode") and self._has_mamba:
            self._truncate_mamba_request_for_prefill(request)

        # No remote prefill for this request.
        return 0, False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Record what this request needs (save and/or recv) after block
        allocation, based on its ``kv_transfer_params``.

        Requests without params are ignored. ``do_remote_prefill`` is reset
        to False once handled so that only one transfer is triggered.
        """
        params = request.kv_transfer_params
        logger.debug(
            "NIXLConnector update_state_after_alloc: "
            "num_external_tokens=%s, kv_transfer_params=%s",
            num_external_tokens,
            params,
        )

        if not params:
            return

        if params.get("do_remote_decode"):
            self._reqs_in_batch.add(request.request_id)
        if self.use_host_buffer and params.get("do_remote_decode"):
            # NOTE: when accelerator is not directly supported by Nixl,
            # prefilled blocks need to be saved to host memory before transfer.
            self._reqs_need_save[request.request_id] = request
        elif params.get("do_remote_prefill"):
            if params.get("remote_block_ids"):
                if all(
                    p in params
                    for p in (
                        "remote_engine_id",
                        "remote_request_id",
                        "remote_host",
                        "remote_port",
                    )
                ):
                    # If remote_blocks and num_external_tokens = 0, we have
                    # a full prefix cache hit on the D worker. We need to call
                    # send_notif in _read_blocks to free the memory on the P.
                    unhashed_local_block_ids: BlockIds = (
                        blocks.get_unhashed_block_ids_all_groups()
                        if num_external_tokens > 0
                        else ()
                    )
                    local_block_ids = self.get_sw_clipped_blocks(
                        unhashed_local_block_ids
                    )
                    # Get unhashed blocks to pull from remote. Mind that a full prefix
                    # cache hit is indicated with an empty list.
                    self._reqs_need_recv[request.request_id] = (
                        request,
                        local_block_ids,
                    )
                else:
                    logger.warning(
                        "Got invalid KVTransferParams: %s. This "
                        "request will not utilize KVTransfer",
                        params,
                    )
            else:
                assert num_external_tokens == 0
            # Only trigger 1 KV transfer per request.
            params["do_remote_prefill"] = False

    def _build_save_meta(
        self,
        meta: NixlConnectorMetadata,
        scheduler_output: SchedulerOutput,
    ) -> None:
        """Add save entries to ``meta`` for requests tracked in
        ``_reqs_need_save`` that received new blocks this step."""
        # only called when use_host_buffer is True to build the save metadata
        # NOTE: For the prefill side, there might be a chance that an early added
        # request is a chunked prefill, so we need to check if new blocks are added
        for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
            req_to_save = self._reqs_need_save.get(req_id)
            if req_to_save is None or new_block_id_groups is None:
                continue
            req = req_to_save

            assert req.kv_transfer_params is not None
            clipped_block_id_groups = self.get_sw_clipped_blocks(new_block_id_groups)
            meta.add_new_req_to_save(
                request_id=req_id,
                local_block_ids=clipped_block_id_groups,
                kv_transfer_params=req.kv_transfer_params,
            )
            assert scheduler_output.num_scheduled_tokens is not None
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
            is_partial = (
                req.num_computed_tokens + num_scheduled_tokens
            ) < req.num_prompt_tokens
            if not is_partial:
                # For non-partial prefills, once new req_meta is scheduled, it
                # can be removed from _reqs_need_save.
                # For partial prefill case, we will retain the request in
                # _reqs_need_save until all blocks are scheduled with req_meta.
                # Therefore, only pop if `not is_partial`.
                self._reqs_need_save.pop(req_id)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Snapshot the pending recv/save/send state into a fresh
        ``NixlConnectorMetadata`` and reset the scheduler-side tracking."""
        meta = NixlConnectorMetadata()

        # Loop through scheduled reqs and convert to ReqMeta.
        for req_id, (req, block_ids) in self._reqs_need_recv.items():
            assert req.kv_transfer_params is not None
            meta.add_new_req_to_recv(
                request_id=req_id,
                local_block_ids=block_ids,
                kv_transfer_params=req.kv_transfer_params,
            )

        if self.use_host_buffer:
            self._build_save_meta(meta, scheduler_output)

        meta.reqs_to_send = self._reqs_need_send
        meta.reqs_in_batch = self._reqs_in_batch
        meta.reqs_not_processed = self._reqs_not_processed

        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()
        # These containers were handed to ``meta`` above, so rebind fresh
        # ones instead of clearing them in place.
        self._reqs_in_batch = set()
        self._reqs_not_processed = set()
        self._reqs_need_send = {}

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: BlockIds,
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Once a request is finished, determine whether request blocks
        should be freed now or will be sent asynchronously and freed later.

        Returns:
            ``(delay_free_blocks, kv_transfer_params)``. The params dict is
            ``None`` unless this was a length-capped remote-decode request.
        """
        from vllm.v1.request import RequestStatus

        params = request.kv_transfer_params
        logger.debug(
            "NIXLConnector request_finished(%s), request_status=%s, "
            "kv_transfer_params=%s",
            request.request_id,
            request.status,
            params,
        )
        if not params:
            return False, None

        if params.get("do_remote_prefill"):
            # If do_remote_prefill is still True when the request is finished,
            # update_state_after_alloc must not have been called (the request
            # must have been aborted before it was scheduled).
            # To avoid stranding the prefill blocks in the prefill instance,
            # we must add empty block_ids to _reqs_need_recv so that our
            # worker side will notify and free blocks in the prefill instance.
            self._reqs_need_recv[request.request_id] = (request, [])
            params["do_remote_prefill"] = False
            return False, None

        if not params.get("do_remote_decode"):
            return False, None
        if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
            # Also include the case of a P/D Prefill request with immediate
            # block free (eg abort). Stop tracking this request.
            self._reqs_not_processed.add(request.request_id)
            # Clear _reqs_need_save if a request is aborted as partial prefill.
            self._reqs_need_save.pop(request.request_id, None)
            return False, None

        # TODO: check whether block_ids actually ever be 0. If not we could
        # remove the conditional below
        delay_free_blocks = any(len(group) > 0 for group in block_ids)

        if delay_free_blocks:
            # Prefill request on remote. It will be read from D upon completion
            logger.debug(
                "NIXLConnector request_finished(%s) waiting for %d seconds "
                "for remote decode to fetch blocks",
                request.request_id,
                envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT,
            )
            self._reqs_need_send[request.request_id] = (
                time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
            )
            # NOTE HMA will "mark" empty/null blocks in groups with 0s (eg SWA ones),
            # trimming down after allocating for the whole sequence length. Empty
            # blocks are always at the start of the list.
            # Here we "unpad" blocks to send the actual remote blocks to be read.
            block_ids = self.get_sw_clipped_blocks(block_ids)

        return delay_free_blocks, dict(
            do_remote_prefill=True,
            do_remote_decode=False,
            remote_block_ids=block_ids,
            remote_engine_id=self.engine_id,
            remote_request_id=request.request_id,
            remote_host=self.side_channel_host,
            remote_port=self.side_channel_port,
            tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Stats and Prometheus metrics for the NIXL connector."""
import copy
from dataclasses import dataclass
from typing import Any
import numpy as np
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import (
nixlXferTelemetry,
)
from vllm.v1.metrics.utils import create_metric_per_engine
@dataclass
class NixlKVConnectorStats(KVConnectorStats):
    """Container for transfer performance metrics"""

    def __post_init__(self):
        # Empty container init, no data is passed in.
        if not self.data:
            self.reset()

    def reset(self):
        """Replace ``data`` with a fresh set of empty per-metric lists."""
        # Must be serializable
        metric_names = (
            "transfer_duration",
            "post_duration",
            "bytes_transferred",
            "num_descriptors",
            "num_failed_transfers",
            "num_failed_notifications",
            "num_kv_expired_reqs",
        )
        self.data: dict[str, list[float | int]] = {name: [] for name in metric_names}

    def record_transfer(self, res: nixlXferTelemetry):
        """Append one successful transfer's telemetry to the metric lists."""
        data = self.data
        # Keep metrics units consistent with rest of the code: time us->s
        data["transfer_duration"].append(res.xferDuration / 1e6)
        data["post_duration"].append(res.postDuration / 1e6)
        data["bytes_transferred"].append(res.totalBytes)
        data["num_descriptors"].append(res.descCount)

    def record_failed_transfer(self):
        """Record a failed NIXL transfer operation."""
        self.data["num_failed_transfers"].append(1)

    def record_failed_notification(self):
        """Record a failed NIXL notification (send_notif)."""
        self.data["num_failed_notifications"].append(1)

    def record_kv_expired_req(self):
        """Record a request that had its KV blocks expire."""
        self.data["num_kv_expired_reqs"].append(1)

    def clone_and_reset(self) -> "NixlKVConnectorStats":
        """Return a shallow copy holding the current data, then reset self.

        ``reset`` rebinds ``self.data`` to a new dict, so the returned copy
        keeps the previously collected lists.
        """
        snapshot = copy.copy(self)
        self.reset()
        return snapshot

    def is_empty(self) -> bool:
        """True when neither successes nor any failure kind were recorded."""
        # Do not discard metrics update that are entirely failures related.
        failure_keys = (
            "num_failed_transfers",
            "num_failed_notifications",
            "num_kv_expired_reqs",
        )
        return self.num_successful_transfers == 0 and not any(
            len(self.data[key]) for key in failure_keys
        )

    def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
        """Extend this container's lists with ``other``'s and return self."""
        if other.is_empty():
            return self
        for key, values in other.data.items():
            accumulator = self.data[key]
            assert isinstance(accumulator, list)
            accumulator.extend(values)
        return self

    def reduce(self) -> dict[str, int | float]:
        """Compute compact representative stats suitable for CLI logging."""
        if self.num_successful_transfers == 0:
            # CLI logging only reports successful transfers stats. If all requests in
            # the interval were unsuccessful, Prom will report failures stats instead.
            return dict.fromkeys(
                (
                    "Num successful transfers",
                    "Avg xfer time (ms)",
                    "P90 xfer time (ms)",
                    "Avg post time (ms)",
                    "P90 post time (ms)",
                    "Avg MB per transfer",
                    "Throughput (MB/s)",
                    "Avg number of descriptors",
                ),
                0,
            )

        data = self.data
        xfer_seconds = np.asarray(data["transfer_duration"])
        post_seconds = np.asarray(data["post_duration"])
        # Convert to MB for CLI logging.
        megabytes = np.asarray(data["bytes_transferred"]) / 2**20
        descriptor_counts = np.asarray(data["num_descriptors"], dtype=np.uint32)

        num_transfers = len(descriptor_counts)
        assert num_transfers == self.num_successful_transfers

        total_mb = megabytes.sum()
        mean_mb = total_mb / num_transfers
        total_xfer_seconds = xfer_seconds.sum()
        mb_per_second = total_mb / total_xfer_seconds

        p90_xfer_seconds = np.percentile(xfer_seconds, 90).item()
        p90_post_seconds = np.percentile(post_seconds, 90).item()

        return {
            "Num successful transfers": num_transfers,
            "Avg xfer time (ms)": round(xfer_seconds.mean() * 1e3, 3),
            "P90 xfer time (ms)": round(p90_xfer_seconds * 1e3, 3),
            "Avg post time (ms)": round(post_seconds.mean() * 1e3, 3),
            "P90 post time (ms)": round(p90_post_seconds * 1e3, 3),
            "Avg MB per transfer": round(mean_mb, 3),
            "Throughput (MB/s)": round(mb_per_second, 3),
            "Avg number of descriptors": round(descriptor_counts.mean(), 1),
        }

    @property
    def num_successful_transfers(self) -> int:
        """Number of recorded transfer durations."""
        return len(self.data["transfer_duration"])
class NixlPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
buckets = [
0.001,
0.005,
0.01,
0.025,
0.05,
0.075,
0.1,
0.2,
0.3,
0.5,
0.75,
1.0,
5.0,
]
nixl_histogram_xfer_time = self._histogram_cls(
name="vllm:nixl_xfer_time_seconds",
documentation="Histogram of transfer duration for NIXL KV Cache transfers.",
buckets=buckets[1:],
labelnames=labelnames,
)
self.nixl_histogram_xfer_time = create_metric_per_engine(
nixl_histogram_xfer_time, self.per_engine_labelvalues
)
nixl_histogram_post_time = self._histogram_cls(
name="vllm:nixl_post_time_seconds",
documentation="Histogram of transfer post time for NIXL KV"
" Cache transfers.",
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_post_time = create_metric_per_engine(
nixl_histogram_post_time, self.per_engine_labelvalues
)
# uniform 2kb to 16gb range
buckets = [2 ** (10 + i) for i in range(1, 25, 2)]
nixl_histogram_bytes_transferred = self._histogram_cls(
name="vllm:nixl_bytes_transferred",
documentation="Histogram of bytes transferred per NIXL KV Cache transfers.",
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_bytes_transferred = create_metric_per_engine(
nixl_histogram_bytes_transferred, self.per_engine_labelvalues
)
buckets = [
10,
20,
30,
50,
75,
100,
200,
400,
1000,
2000,
4000,
10000,
20000,
50000,
]
nixl_histogram_num_descriptors = self._histogram_cls(
name="vllm:nixl_num_descriptors",
documentation="Histogram of number of descriptors per NIXL"
" KV Cache transfers.",
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_num_descriptors = create_metric_per_engine(
nixl_histogram_num_descriptors, self.per_engine_labelvalues
)
counter_nixl_num_failed_transfers = self._counter_cls(
name="vllm:nixl_num_failed_transfers",
documentation="Number of failed NIXL KV Cache transfers.",
labelnames=labelnames,
)
self.counter_nixl_num_failed_transfers = create_metric_per_engine(
counter_nixl_num_failed_transfers, self.per_engine_labelvalues
)
counter_nixl_num_failed_notifications = self._counter_cls(
name="vllm:nixl_num_failed_notifications",
documentation="Number of failed NIXL KV Cache notifications.",
labelnames=labelnames,
)
self.counter_nixl_num_failed_notifications = create_metric_per_engine(
counter_nixl_num_failed_notifications, self.per_engine_labelvalues
)
counter_nixl_num_kv_expired_reqs = self._counter_cls(
name="vllm:nixl_num_kv_expired_reqs",
documentation="Number of requests that had their KV expire. "
"NOTE: This metric is tracked on the P instance.",
labelnames=labelnames,
)
self.counter_nixl_num_kv_expired_reqs = create_metric_per_engine(
counter_nixl_num_kv_expired_reqs, self.per_engine_labelvalues
)
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
    """Feed one batch of NIXL transfer stats into the Prometheus metrics.

    Args:
        transfer_stats_data: Mapping that must contain every key read below;
            each value is iterated and every element is recorded individually.
        engine_idx: Index used to pick the per-engine metric instance.
    """
    # Histogram-backed series: every sample is observed on its own.
    histogram_series = (
        ("transfer_duration", self.nixl_histogram_xfer_time),
        ("post_duration", self.nixl_histogram_post_time),
        ("bytes_transferred", self.nixl_histogram_bytes_transferred),
        ("num_descriptors", self.nixl_histogram_num_descriptors),
    )
    for stats_key, per_engine_hist in histogram_series:
        for sample in transfer_stats_data[stats_key]:
            per_engine_hist[engine_idx].observe(sample)

    # Counter-backed series: every element is added to the running total.
    counter_series = (
        ("num_failed_transfers", self.counter_nixl_num_failed_transfers),
        ("num_failed_notifications", self.counter_nixl_num_failed_notifications),
        ("num_kv_expired_reqs", self.counter_nixl_num_kv_expired_reqs),
    )
    for stats_key, per_engine_counter in counter_series:
        for amount in transfer_stats_data[stats_key]:
            per_engine_counter[engine_idx].inc(amount)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Shared constants, lazy imports and helpers for the NIXL connector."""
import contextlib
import os
import sys
from collections.abc import Iterator
from typing import Any
import zmq
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_socket
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
    # Only touch UCX_RCACHE_MAX_UNRELEASED when the user has not set it.
    if "UCX_RCACHE_MAX_UNRELEASED" not in os.environ:
        # avoid a memory leak in UCX when using NIXL on some models
        # see: https://github.com/vllm-project/vllm/issues/24264
        if "nixl" not in sys.modules and "rixl" not in sys.modules:
            logger.info(
                "Setting UCX_RCACHE_MAX_UNRELEASED to '1024' to avoid a rare "
                "memory leak in UCX when using NIXL."
            )
            os.environ["UCX_RCACHE_MAX_UNRELEASED"] = "1024"
        else:
            logger.warning(
                "NIXL was already imported, we can't reset UCX_RCACHE_MAX_UNRELEASED. "
                "Please set it to '1024' manually."
            )

    # ROCm platforms import the bindings from `rixl`, all others from `nixl`.
    if current_platform.is_rocm():
        from rixl._api import nixl_agent as NixlWrapper
        from rixl._bindings import nixlXferTelemetry
    else:
        from nixl._api import nixl_agent as NixlWrapper
        from nixl._bindings import nixlXferTelemetry

    logger.info("NIXL is available")
except ImportError:
    # Leave placeholders so that importing this module never fails.
    logger.warning("NIXL is not available")
    NixlWrapper = None
    nixlXferTelemetry = None

# The agent config is imported separately so that its absence is reported
# on its own and does not affect the names bound above.
try:
    if current_platform.is_rocm():
        from rixl._api import nixl_agent_config
    else:
        from nixl._api import nixl_agent_config
except ImportError:
    nixl_agent_config = None
    logger.warning("NIXL agent config is not available")
# Supported platforms and types of kv transfer buffer.
# {device: tuple of supported kv buffer types}
_NIXL_SUPPORTED_DEVICE = {
    "cuda": ("cuda", "cpu"),
    "tpu": ("cpu",),
    "xpu": ("cpu", "xpu"),
    "cpu": ("cpu",),
}
# Out-of-tree platforms add (or override) entries through the mapping
# returned by current_platform.
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
# TODO: merge with vllm.utils.network_utils.zmq_socket_ctx
@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
    """Yield a ZMQ socket on ``addr`` and tear its context down on exit.

    Args:
        socket_type: Must be ``zmq.ROUTER`` or ``zmq.REQ``.
        addr: Path handed to ``make_zmq_socket``.

    Raises:
        ValueError: If ``socket_type`` is neither ROUTER nor REQ.
    """
    if socket_type not in (zmq.ROUTER, zmq.REQ):
        raise ValueError(f"Unexpected socket type: {socket_type}")

    ctx = zmq.Context()  # type: ignore[attr-defined]
    try:
        # Binding is requested only for ROUTER sockets.
        should_bind = socket_type == zmq.ROUTER
        yield make_zmq_socket(
            ctx=ctx, path=addr, socket_type=socket_type, bind=should_bind
        )
    finally:
        # linger=0: do not wait on unsent messages when destroying the context.
        ctx.destroy(linger=0)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib """Worker-side logic for the NIXL connector."""
import copy
import logging import logging
import os import os
import queue import queue
import sys
import threading import threading
import time import time
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
import msgspec import msgspec
...@@ -21,32 +18,36 @@ import torch ...@@ -21,32 +18,36 @@ import torch
import zmq import zmq
from vllm import envs from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.distributed.kv_transfer.kv_connector.utils import (
BlockIds, BlockIds,
EngineId, EngineId,
HeteroTPTransferConfig, HeteroTPTransferConfig,
TpKVTopology, TpKVTopology,
get_current_attn_backend,
get_current_attn_backends, get_current_attn_backends,
kv_postprocess_blksize_and_layout_on_receive, kv_postprocess_blksize_and_layout_on_receive,
kv_postprocess_blksize_on_receive, kv_postprocess_blksize_on_receive,
kv_postprocess_layout_on_receive, kv_postprocess_layout_on_receive,
yield_req_data,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import CopyBlocksOp
CopyBlocksOp, from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
KVConnectorBase_V1, from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
KVConnectorHandshakeMetadata, GET_META_MSG,
KVConnectorMetadata, NixlAgentMetadata,
KVConnectorRole, NixlConnectorMetadata,
SupportsHMA, NixlHandshakePayload,
ReqId,
ReqMeta,
TransferHandle,
compute_nixl_compatibility_hash,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.stats import (
NixlKVConnectorStats,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import (
KVConnectorPromMetrics, _NIXL_SUPPORTED_DEVICE,
KVConnectorStats, NixlWrapper,
PromMetric, nixl_agent_config,
PromMetricT, zmq_ctx,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import ( from vllm.distributed.kv_transfer.kv_connector.v1.ssm_conv_transfer_utils import (
MambaConvSplitInfo, MambaConvSplitInfo,
...@@ -57,961 +58,34 @@ from vllm.distributed.parallel_state import ( ...@@ -57,961 +58,34 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first from vllm.model_executor.layers.mamba.mamba_utils import is_conv_state_dim_first
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.network_utils import make_zmq_path
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
FullAttentionSpec, FullAttentionSpec,
MambaSpec, MambaSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs, UniformTypeKVCacheSpecs,
) )
from vllm.v1.metrics.utils import create_metric_per_engine
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.utils import select_common_block_size from vllm.v1.worker.utils import select_common_block_size
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
# Type aliases used in annotations throughout the connector.
TransferHandle = int
ReqId = str
#
# NIXL Connector Version
#
# Increment this version whenever there is an incompatible change to:
# - NixlAgentMetadata schema
# - kv_transfer_params schema or semantics
# - NIXL transfer protocol or wire format
# - KV cache memory layout or block organization
# - Any other change that breaks P/D interoperability
#
# Version History:
# 1: Initial version with compatibility checking
# 2: Add remote_request_id to kv_transfer_params
#
# This value is one of the factors hashed by compute_nixl_compatibility_hash.
NIXL_CONNECTOR_VERSION: int = 2
# Message the handshake listener expects as the first element of a decoded
# metadata request.
GET_META_MSG = b"get_meta_msg"
logger = init_logger(__name__) logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
    # Only touch UCX_RCACHE_MAX_UNRELEASED when the user has not set it.
    if "UCX_RCACHE_MAX_UNRELEASED" not in os.environ:
        # avoid a memory leak in UCX when using NIXL on some models
        # see: https://github.com/vllm-project/vllm/issues/24264
        if "nixl" not in sys.modules and "rixl" not in sys.modules:
            logger.info(
                "Setting UCX_RCACHE_MAX_UNRELEASED to '1024' to avoid a rare "
                "memory leak in UCX when using NIXL."
            )
            os.environ["UCX_RCACHE_MAX_UNRELEASED"] = "1024"
        else:
            logger.warning(
                "NIXL was already imported, we can't reset UCX_RCACHE_MAX_UNRELEASED. "
                "Please set it to '1024' manually."
            )

    # ROCm platforms import the bindings from `rixl`, all others from `nixl`.
    if current_platform.is_rocm():
        from rixl._api import nixl_agent as NixlWrapper
        from rixl._bindings import nixlXferTelemetry
    else:
        from nixl._api import nixl_agent as NixlWrapper
        from nixl._bindings import nixlXferTelemetry

    logger.info("NIXL is available")
except ImportError:
    # Leave placeholders so that importing this module never fails.
    logger.warning("NIXL is not available")
    NixlWrapper = None
    nixlXferTelemetry = None

# The agent config is imported separately so that its absence is reported
# on its own and does not affect the names bound above.
try:
    if current_platform.is_rocm():
        from rixl._api import nixl_agent_config
    else:
        from nixl._api import nixl_agent_config
except ImportError:
    nixl_agent_config = None
    logger.warning("NIXL agent config is not available")
# Supported platforms and types of kv transfer buffer.
# {device: tuple of supported kv buffer types}
_NIXL_SUPPORTED_DEVICE = {
    "cuda": ("cuda", "cpu"),
    "tpu": ("cpu",),
    "xpu": ("cpu", "xpu"),
    "cpu": ("cpu",),
}
# Out-of-tree platforms add (or override) entries through the mapping
# returned by current_platform.
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
@dataclass
class NixlAgentMetadata:
    """Metadata record describing one NIXL agent.

    NixlHandshakePayload carries an encoded instance of this record in its
    ``agent_metadata_bytes`` field.
    """

    engine_id: str
    agent_metadata: bytes
    kv_caches_base_addr: list[int]
    device_id: int
    num_blocks: int
    block_lens: list[int]
    kv_cache_layout: str
    block_size: int
    ssm_sizes: tuple[int, int]
    attn_backend_name: str
@dataclass
class NixlHandshakePayload(KVConnectorHandshakeMetadata):
    """Envelope for the NIXL handshake as it travels over the wire.

    The envelope allows decoding in two phases so that incompatible peers
    fail gracefully:

    1. Decode the NixlHandshakePayload itself to read ``compatibility_hash``.
    2. Compute the local hash and compare the two.
    3. Decode ``agent_metadata_bytes`` only when the hashes match.

    This avoids decoder errors when the NixlAgentMetadata schema differs
    between peers, so the failure can be reported with a clear message.
    """

    compatibility_hash: str
    agent_metadata_bytes: bytes  # NixlAgentMetadata encoded
def compute_nixl_compatibility_hash(
    vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool
) -> str:
    """Return the hash two NIXL instances must share to exchange KV cache.

    Only factors that decide whether a transfer can succeed are hashed:

    - vLLM version and NIXL connector version
    - Model architecture (name, dtype, KV heads, head size, layers)
    - KV cache dtype
    - Attention backend, cross-layer block layout, HMA on/off

    Note: tensor_parallel_size, block_size and kv_cache_layout are left out
    on purpose to support heterogeneous deployments; they are validated at
    runtime in _validate_remote_agent_handshake.

    Note: the set of factors is likely to evolve over time and become more
    or less permissive.

    Returns:
        The digest produced by ``hash_factors`` (SHA-256 hex digest).
    """
    from vllm import __version__ as vllm_version
    from vllm.config.utils import hash_factors

    model_cfg = vllm_config.model_config
    cache_cfg = vllm_config.cache_config
    hma_enabled = not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager

    # Model architecture - affects KV cache shape.
    model_name = model_cfg.model
    model_dtype = str(model_cfg.dtype)
    total_kv_heads = model_cfg.get_total_num_kv_heads()
    head_size = model_cfg.get_head_size()
    total_layers = model_cfg.get_total_num_hidden_layers()
    # KV cache dtype affects memory layout.
    kv_cache_dtype = str(cache_cfg.cache_dtype)

    factors = {
        # Version compatibility
        "vllm_version": vllm_version,
        "nixl_connector_version": NIXL_CONNECTOR_VERSION,
        "model": model_name,
        "dtype": model_dtype,
        "num_kv_heads": total_kv_heads,
        "head_size": head_size,
        "num_hidden_layers": total_layers,
        "attn_backend_name": attn_backend_name,
        "cache_dtype": kv_cache_dtype,
        "cross_layers_blocks": cross_layers_blocks,
        "is_hma_enabled": hma_enabled,
    }

    digest = hash_factors(factors)
    logger.debug(
        "NIXL compatibility hash: %s (model=%s, dtype=%s, num_kv_heads=%d, "
        "cache_dtype=%s, attn_backend=%s)",
        digest,
        model_name,
        model_dtype,
        total_kv_heads,
        kv_cache_dtype,
        attn_backend_name,
    )
    return digest
@dataclass
class RemoteMeta:
    """Remote-side details of a request whose KV blocks are to be received.

    NixlConnectorMetadata.add_new_req_to_recv fills every field from the
    ``remote_*`` entries of the request's kv_transfer_params.
    """

    block_ids: BlockIds
    host: str
    port: int
    engine_id: str
    request_id: str
@dataclass
class ReqMeta:
    """Per-request transfer metadata built by the connector metadata class."""

    local_block_ids: BlockIds
    # To be used when logical block size does not match the kernel block size
    local_physical_block_ids: BlockIds
    tp_size: int
    # Populated only for requests that receive KV from a remote engine.
    remote: RemoteMeta | None = None
class NixlConnectorMetadata(KVConnectorMetadata):
    """Request bookkeeping that the scheduler side hands to the worker side."""

    def __init__(self):
        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
        self.reqs_to_save: dict[ReqId, ReqMeta] = {}
        self.reqs_to_send: dict[ReqId, float] = {}
        self.reqs_in_batch: set[ReqId] = set()
        self.reqs_not_processed: set[ReqId] = set()

    def _add_new_req(
        self,
        local_block_ids: BlockIds,
        kv_transfer_params: dict[str, Any],
    ) -> ReqMeta:
        """Build a ReqMeta with no remote part; tp_size defaults to 1."""
        # P workers don't need to receive tp_size from proxy here.
        tp_size = kv_transfer_params.get("tp_size", 1)
        return ReqMeta(
            local_block_ids=local_block_ids,
            local_physical_block_ids=local_block_ids,
            tp_size=tp_size,
        )

    def add_new_req_to_save(
        self,
        request_id: ReqId,
        local_block_ids: BlockIds,
        kv_transfer_params: dict[str, Any],
    ):
        """Register ``request_id`` in ``reqs_to_save``."""
        req_meta = self._add_new_req(local_block_ids, kv_transfer_params)
        self.reqs_to_save[request_id] = req_meta

    def add_new_req_to_recv(
        self,
        request_id: ReqId,
        local_block_ids: BlockIds,
        kv_transfer_params: dict[str, Any],
    ):
        """Register ``request_id`` in ``reqs_to_recv`` with its remote info.

        Raises:
            KeyError: If any required ``remote_*`` entry is missing from
                ``kv_transfer_params``.
        """
        req_meta = self._add_new_req(local_block_ids, kv_transfer_params)
        req_meta.remote = RemoteMeta(
            block_ids=kv_transfer_params["remote_block_ids"],
            engine_id=kv_transfer_params["remote_engine_id"],
            request_id=kv_transfer_params["remote_request_id"],
            host=kv_transfer_params["remote_host"],
            port=kv_transfer_params["remote_port"],
        )
        self.reqs_to_recv[request_id] = req_meta
class NixlConnector(KVConnectorBase_V1, SupportsHMA):
    """KV connector that forwards calls to a scheduler- or worker-side helper.

    The SCHEDULER role instantiates ``connector_scheduler`` and the WORKER
    role instantiates ``connector_worker``; the other attribute is ``None``.
    """

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        """Whether the cross-layer KV block layout should be used."""
        has_mamba_group = any(
            isinstance(group.kv_cache_spec, MambaSpec)
            for group in self.kv_cache_config.kv_cache_groups
        )
        if has_mamba_group:
            # Hybrid SSM models do not yet support cross-layer layout
            return False

        backend_name = get_current_attn_backend(self._vllm_config).get_name()
        if backend_name not in ("FLASH_ATTN", "FLASHINFER", "TRITON_ATTN"):
            return False

        # For now there is no benefit to run cross layers when backend
        # does not support on HND
        if get_kv_cache_layout() != "HND":
            return False

        # Opt-in through the connector's extra config; any value whose
        # lowercased string form is "true" enables it.
        extra_config = self.kv_transfer_config.kv_connector_extra_config
        flag = str(extra_config.get("enable_cross_layers_blocks", "False"))
        return flag.lower() == "true"

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig",
    ):
        super().__init__(vllm_config, role, kv_cache_config)

        transfer_config = vllm_config.kv_transfer_config
        assert transfer_config is not None
        assert transfer_config.engine_id is not None
        self.kv_cache_config = kv_cache_config
        self.engine_id: EngineId = transfer_config.engine_id
        self.kv_transfer_config = transfer_config

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler: NixlConnectorScheduler | None = (
                NixlConnectorScheduler(vllm_config, self.engine_id, kv_cache_config)
            )
            self.connector_worker: NixlConnectorWorker | None = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = NixlConnectorWorker(
                vllm_config, self.engine_id, kv_cache_config
            )

    ############################################################
    # Class Methods
    ############################################################
    @classmethod
    def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
        """Return "HND", or None when no layout needs to be enforced."""
        model_config = vllm_config.model_config
        if model_config is None:
            logger.warning_once(
                "Unable to detect current VLLM config. "
                "Fallback to default kv cache layout."
            )
            return None

        if model_config.use_mla:
            # return None when we have mla
            # as the layout should not matter in that case,
            # which fallback to the default behavior.
            return None

        logger.info_once(
            "NixlConnector setting KV cache layout to HND for better xfer performance."
        )
        return "HND"

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.get_num_new_matched_tokens(request, num_computed_tokens)

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        scheduler = self.connector_scheduler
        assert scheduler is not None
        # The scheduler expects one entry per KV cache group; wrap the single
        # list in a one-element tuple.
        return scheduler.request_finished(request, (block_ids,))

    def request_finished_all_groups(
        self,
        request: "Request",
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
        scheduler = self.connector_scheduler
        assert scheduler is not None
        return scheduler.request_finished(request, block_ids)

    def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]
    ) -> None:
        """
        Set the KV connector handshake metadata for this connector.

        Args:
            metadata (dict): the handshake metadata to set.
        """
        scheduler = self.connector_scheduler
        assert scheduler is not None
        scheduler.set_xfer_handshake_metadata(metadata)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        worker = self.connector_worker
        assert worker is not None
        worker.register_kv_caches(kv_caches)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        # attn_backend is accepted but not forwarded to the worker.
        worker = self.connector_worker
        assert worker is not None
        worker.register_cross_layers_kv_caches(kv_cache)

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
        worker = self.connector_worker
        assert worker is not None
        worker.set_host_xfer_buffer_ops(copy_operation)

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Get the finished recving and sending requests."""
        # finished_req_ids is not used; the worker tracks completion itself.
        worker = self.connector_worker
        assert worker is not None
        return worker.get_finished()

    def get_block_ids_with_load_errors(self) -> set[int]:
        """Get block IDs that failed to load via NIXL."""
        worker = self.connector_worker
        assert worker is not None
        return worker.get_block_ids_with_load_errors()

    def get_kv_connector_stats(self) -> KVConnectorStats | None:
        worker = self.connector_worker
        if worker is None:
            return None
        return worker.get_kv_connector_stats()

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> KVConnectorStats | None:
        if data is None:
            return NixlKVConnectorStats()
        return NixlKVConnectorStats(data=data)

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: VllmConfig,
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ) -> KVConnectorPromMetrics:
        return NixlPromMetrics(
            vllm_config, metric_types, labelnames, per_engine_labelvalues
        )

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        worker = self.connector_worker
        assert worker is not None
        assert isinstance(self._connector_metadata, NixlConnectorMetadata)
        worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """NixlConnector does not do layerwise saving."""
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> None:
        """NixlConnector does not save explicitly."""
        pass

    def wait_for_save(self):
        worker = self.connector_worker
        assert worker is not None
        assert isinstance(self._connector_metadata, NixlConnectorMetadata)
        # Saving to host happens only when a host buffer is in use and a
        # copy operation has been set on the worker.
        if worker.use_host_buffer and worker.copy_blocks:
            worker.save_kv_to_host(self._connector_metadata)

    def shutdown(self):
        worker = self.connector_worker
        if worker is not None:
            worker.shutdown()
        scheduler = self.connector_scheduler
        if scheduler is not None:
            scheduler.shutdown()

    def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
        """
        Get the KVConnector handshake metadata for this connector.
        This metadata is used for out-of-band connector handshake
        between P/D workers.

        Returns:
            KVConnectorHandshakeMetadata: the handshake metadata.
            None if no handshake metadata is available.
        """
        worker = self.connector_worker
        assert worker is not None
        return worker.xfer_handshake_metadata
class NixlConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(
    self, vllm_config: VllmConfig, engine_id: str, kv_cache_config: "KVCacheConfig"
):
    """Set up scheduler-side state for the NIXL connector.

    Args:
        vllm_config: Engine configuration; ``kv_transfer_config`` must be set.
        engine_id: Identifier of this engine.
        kv_cache_config: KV cache configuration whose groups are inspected
            for Mamba and sliding-window specs.
    """
    self.vllm_config = vllm_config
    self.block_size = vllm_config.cache_config.block_size
    self.engine_id: EngineId = engine_id
    self.kv_cache_config = kv_cache_config
    self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
    # Offset the base port by the data-parallel index.
    self.side_channel_port = (
        envs.VLLM_NIXL_SIDE_CHANNEL_PORT
        + vllm_config.parallel_config.data_parallel_index
    )
    assert vllm_config.kv_transfer_config is not None
    if current_platform.device_type == "cpu":
        self.use_host_buffer = False
    else:
        self.use_host_buffer = (
            vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
        )

    is_hma_enabled = not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
    self._is_hma_required = (
        is_hma_enabled
        # Also handle unlikely SW-only model case instead of checking num_groups>1.
        and any(
            not isinstance(g.kv_cache_spec, FullAttentionSpec)
            for g in kv_cache_config.kv_cache_groups
        )
    )
    self._has_mamba = any(
        isinstance(g.kv_cache_spec, MambaSpec)
        for g in kv_cache_config.kv_cache_groups
    )

    logger.info("Initializing NIXL Scheduler %s", engine_id)
    # Report HMA as enabled only when the hybrid KV cache manager is NOT
    # disabled (the flag is a "disable" switch).
    if is_hma_enabled:
        logger.info("Hybrid Memory Allocator is enabled with NIXL")

    # Background thread for handling new handshake requests.
    self._nixl_handshake_listener_t: threading.Thread | None = None
    self._stop_event = threading.Event()

    # Requests that need to start recv/send.
    # New requests are added by update_state_after_alloc in
    # the scheduler. Used to make metadata passed to Worker.
    self._reqs_need_recv: dict[ReqId, tuple[Request, BlockIds]] = {}
    self._reqs_need_save: dict[ReqId, Request] = {}
    # Reqs to send and their expiration time
    self._reqs_need_send: dict[ReqId, float] = {}
    self._reqs_in_batch: set[ReqId] = set()
    # Reqs to remove from processed set because they're not to send after
    # remote prefill or aborted.
    self._reqs_not_processed: set[ReqId] = set()

    # Gather Sliding Window sizes for each kv cache group (if any) in number of
    # blocks per KV cache group. This is used to clip the local attention window.
    sw_sizes_tokens: list[tuple[int, int]] = [
        (g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size)
        if isinstance(g.kv_cache_spec, SlidingWindowSpec)
        else (0, self.block_size)
        for g in kv_cache_config.kv_cache_groups
    ]
    # cdiv(n_tokens, block_size) gives blocks/window; add 1 to conservatively
    # account for boundary overlap eg window isn't fully aligned with blocks.
    # Groups without a sliding window get 0.
    self.blocks_per_sw = [
        cdiv(n_tokens, block_size) + 1 if n_tokens else 0
        for n_tokens, block_size in sw_sizes_tokens
    ]
def shutdown(self):
    """Signal the handshake listener to stop and wait for it to exit."""
    self._stop_event.set()
    listener = self._nixl_handshake_listener_t
    if listener is None:
        return
    listener.join()
    self._nixl_handshake_listener_t = None
def get_sw_clipped_blocks(self, block_ids: BlockIds) -> BlockIds:
    """
    Keep only the trailing sliding-window blocks for every KV cache group
    that employs SWA; other groups are returned untouched.

    This is necessary because the KV Cache manager initially allocates blocks for
    the entire sequence length, and successively cleans up blocks that are outside
    the window prior to the `request_finished_all_groups` hook.
    """
    if len(block_ids) == 0 or not self._is_hma_required:
        # No blocks to clip eg Full prefix cache hit or not a hybrid model.
        return block_ids

    # NOTE (NickLucche) This logic is currently handled at the connector level
    # because offloading connectors might want to receive the whole sequence even
    # for SWA groups. We will abstract this logic once the interface is more stable
    assert len(block_ids) == len(self.blocks_per_sw), (
        "Number of KV cache groups must match"
    )

    clipped_groups = []
    for group_idx, group_blocks in enumerate(block_ids):
        window_blocks = self.blocks_per_sw[group_idx]
        if window_blocks > 0:
            clipped_groups.append(group_blocks[-window_blocks:])
        else:
            # For non-SWA groups blocks_per_sw is 0: keep all blocks.
            clipped_groups.append(group_blocks)
    return tuple(clipped_groups)
def set_xfer_handshake_metadata(
    self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
    """
    Encode the per-rank handshake metadata and start serving it.

    Args:
        metadata (dict): handshake metadata keyed by TP rank.

    Raises:
        ValueError: If any value is not a NixlHandshakePayload.
    """
    encoder = msgspec.msgpack.Encoder()
    encoded_data: dict[int, bytes] = {}
    for tp_rank, rank_metadata in metadata.items():
        if not isinstance(rank_metadata, NixlHandshakePayload):
            raise ValueError(
                "NixlConnectorScheduler expects NixlHandshakePayload for "
                "handshake metadata."
            )
        blob = encoder.encode(rank_metadata)
        encoded_data[tp_rank] = blob
        logger.debug(
            "Tp rank %d: encoded NixlHandshakePayload size: %s bytes",
            tp_rank,
            str(len(blob)),
        )

    # Only start the listener when we have metadata to serve; a listener
    # that is already running is left as is.
    if self._nixl_handshake_listener_t is not None:
        return

    ready_event = threading.Event()
    listener_args = (
        encoded_data,
        ready_event,
        self._stop_event,
        self.side_channel_port,
    )
    self._nixl_handshake_listener_t = threading.Thread(
        target=self._nixl_handshake_listener,
        args=listener_args,
        daemon=True,
        name="nixl_handshake_listener",
    )
    self._nixl_handshake_listener_t.start()
    ready_event.wait()  # Wait for listener ZMQ socket to be ready.
@staticmethod
def _nixl_handshake_listener(
    encoded_data: dict[int, Any],
    ready_event: threading.Event,
    stop_event: threading.Event,
    port: int,
):
    """Background thread for getting new NIXL handshakes.

    Args:
        encoded_data: Pre-encoded handshake payloads keyed by TP rank.
        ready_event: Set once the ROUTER socket has been created.
        stop_event: When set, the loop exits at its next iteration.
        port: Side-channel port to listen on.
    """
    # NOTE(rob): this is a simple implementation. We will move
    # to a better approach via HTTP endpoint soon.

    # Listen for new requests for metadata.
    host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
    path = make_zmq_path("tcp", host, port)
    logger.debug("Starting listening on path: %s", path)
    with zmq_ctx(zmq.ROUTER, path) as sock:
        # Time out receives so that stop_event is polled regularly.
        sock.setsockopt(zmq.RCVTIMEO, 1000)
        ready_event.set()
        # Check stop_event on every iteration so that steady traffic
        # cannot postpone shutdown indefinitely.
        while not stop_event.is_set():
            try:
                frames = sock.recv_multipart()
            except zmq.Again:
                continue

            # A bad request must not take down the listener thread:
            # every later handshake depends on it staying alive.
            try:
                identity, _, raw_request = frames
                # Decode the message which contains (GET_META_MSG, rank)
                msg, target_tp_rank = msgspec.msgpack.decode(raw_request)
                payload = encoded_data[target_tp_rank]
            except (msgspec.DecodeError, KeyError, TypeError, ValueError) as exc:
                logger.warning(
                    "Connection listener dropped malformed handshake request: %r",
                    exc,
                )
                continue

            logger.debug(
                "Received message for tp rank %s",
                target_tp_rank,
            )
            if msg != GET_META_MSG:
                logger.warning("Connection listener got unexpected message %s", msg)
            sock.send_multipart((identity, b"", payload))
def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int:
    """D-side only. For Mamba models return N-1 (when N > 1), since the
    decoder always recomputes the last token and must start from h(N-1);
    otherwise return N unchanged."""
    if not self._has_mamba or num_prompt_tokens <= 1:
        return num_prompt_tokens
    return num_prompt_tokens - 1
def _truncate_mamba_request_for_prefill(self, request: "Request") -> None:
    """P-side only: drop the last prompt token so the prefiller computes
    h(N-1) instead of h(N). The decoder recomputes the last token to
    derive h(N) correctly.

    The ``_p_side_truncated`` marker in kv_transfer_params prevents a second
    truncation if the request is preempted and rescheduled."""
    params = request.kv_transfer_params
    if params is None:
        return
    # Guard against repeated truncation after preemption/reschedule.
    if params.get("_p_side_truncated"):
        return
    if request.num_prompt_tokens <= 1:
        return

    token_ids = request.prompt_token_ids
    if token_ids is not None:
        token_ids.pop()
    elif request.prompt_embeds is not None:
        request.prompt_embeds = request.prompt_embeds[:-1]
    else:
        # Neither token ids nor embeddings: nothing to truncate.
        return

    request._all_token_ids.pop()
    request.num_prompt_tokens -= 1
    request.max_tokens = 1
    params["_p_side_truncated"] = True
def get_num_new_matched_tokens(
    self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
    """
    For remote prefill, pull all prompt blocks from remote
    asynchronously relative to engine execution.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        * the number of tokens that can be loaded from the
          external KV cache beyond what is already computed.
        * true if the external KV cache tokens will be loaded
          asynchronously (between scheduler steps).
    """
    params = request.kv_transfer_params
    logger.debug(
        "NIXLConnector get_num_new_matched_tokens: "
        "num_computed_tokens=%s, kv_transfer_params=%s",
        num_computed_tokens,
        params,
    )

    if params is None:
        # No remote prefill for this request.
        return 0, False

    if params.get("do_remote_prefill"):
        # Remote prefill: get all prompt blocks from remote.
        prompt_ids = request.prompt_token_ids or []
        loadable = self._mamba_prefill_token_count(len(prompt_ids))
        remaining = loadable - num_computed_tokens
        if remaining > 0:
            return remaining, True

    if params.get("do_remote_decode") and self._has_mamba:
        self._truncate_mamba_request_for_prefill(request)

    # No remote prefill for this request.
    return 0, False
def update_state_after_alloc(
    self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
    """Record what must be saved or received for ``request`` after allocation.

    Remote-decode requests are added to the current batch set and, when a
    host buffer is used, queued for saving. Remote-prefill requests are
    queued for receiving when their transfer params are complete, and their
    ``do_remote_prefill`` flag is cleared afterwards.
    """
    params = request.kv_transfer_params
    logger.debug(
        "NIXLConnector update_state_after_alloc: "
        "num_external_tokens=%s, kv_transfer_params=%s",
        num_external_tokens,
        params,
    )

    if not params:
        return

    remote_decode = params.get("do_remote_decode")
    if remote_decode:
        self._reqs_in_batch.add(request.request_id)

    if self.use_host_buffer and remote_decode:
        # NOTE: when accelerator is not directly supported by Nixl,
        # prefilled blocks need to be saved to host memory before transfer.
        self._reqs_need_save[request.request_id] = request
        return

    if not params.get("do_remote_prefill"):
        return

    required_keys = (
        "remote_engine_id",
        "remote_request_id",
        "remote_host",
        "remote_port",
    )
    if not params.get("remote_block_ids"):
        assert num_external_tokens == 0
    elif all(key in params for key in required_keys):
        # If remote_blocks and num_external_tokens = 0, we have
        # a full prefix cache hit on the D worker. We need to call
        # send_notif in _read_blocks to free the memory on the P.
        if num_external_tokens > 0:
            unhashed_local_block_ids: BlockIds = (
                blocks.get_unhashed_block_ids_all_groups()
            )
        else:
            unhashed_local_block_ids = ()
        local_block_ids = self.get_sw_clipped_blocks(unhashed_local_block_ids)

        # Get unhashed blocks to pull from remote. Mind that a full prefix
        # cache hit is indicated with an empty list.
        self._reqs_need_recv[request.request_id] = (request, local_block_ids)
    else:
        logger.warning(
            "Got invalid KVTransferParams: %s. This "
            "request will not utilize KVTransfer",
            params,
        )

    # Only trigger 1 KV transfer per request.
    params["do_remote_prefill"] = False
def _build_save_meta(
    self,
    meta: NixlConnectorMetadata,
    scheduler_output: SchedulerOutput,
) -> None:
    """Add save entries to ``meta`` for scheduled requests awaiting a save.

    Only called when use_host_buffer is True.
    """
    # NOTE: For the prefill side, there might be a chance that an early added
    # request is a chunked prefill, so we need to check if new blocks are added
    for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
        pending = self._reqs_need_save.get(req_id)
        if pending is None or new_block_id_groups is None:
            continue

        assert pending.kv_transfer_params is not None
        meta.add_new_req_to_save(
            request_id=req_id,
            local_block_ids=self.get_sw_clipped_blocks(new_block_id_groups),
            kv_transfer_params=pending.kv_transfer_params,
        )

        assert scheduler_output.num_scheduled_tokens is not None
        scheduled_now = scheduler_output.num_scheduled_tokens[req_id]
        tokens_after_step = pending.num_computed_tokens + scheduled_now
        if tokens_after_step < pending.num_prompt_tokens:
            # Partial prefill: keep the request in _reqs_need_save until
            # all of its blocks have been scheduled with req_meta.
            continue

        # Non-partial prefill: once new req_meta is scheduled, the request
        # can be removed from _reqs_need_save.
        self._reqs_need_save.pop(req_id)
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Package the scheduler-side bookkeeping into a metadata object.

    Every request queued in ``_reqs_need_recv`` becomes a receive entry;
    when ``use_host_buffer`` is set, save entries are added as well. The
    pending send/in-batch/not-processed collections are handed over to the
    returned metadata and the scheduler-side state is reset afterwards.

    Args:
        scheduler_output: Output of the current scheduling step.

    Returns:
        The populated ``NixlConnectorMetadata`` for this step.
    """
    connector_meta = NixlConnectorMetadata()

    # Convert every request waiting for remote KV into a receive entry.
    for request_id, pending in self._reqs_need_recv.items():
        request, local_blocks = pending
        transfer_params = request.kv_transfer_params
        assert transfer_params is not None
        connector_meta.add_new_req_to_recv(
            request_id=request_id,
            local_block_ids=local_blocks,
            kv_transfer_params=transfer_params,
        )

    if self.use_host_buffer:
        self._build_save_meta(connector_meta, scheduler_output)

    # Hand the accumulated collections over to the metadata object.
    connector_meta.reqs_to_send = self._reqs_need_send
    connector_meta.reqs_in_batch = self._reqs_in_batch
    connector_meta.reqs_not_processed = self._reqs_not_processed

    # Start the next step from a clean slate. The handed-over collections
    # are replaced by fresh objects rather than emptied, so the metadata
    # keeps its contents; the receive map was already copied entry by entry
    # and can be cleared in place.
    self._reqs_need_send = {}
    self._reqs_in_batch = set()
    self._reqs_not_processed = set()
    self._reqs_need_recv.clear()

    return connector_meta
def request_finished(
    self,
    request: "Request",
    block_ids: BlockIds,
) -> tuple[bool, dict[str, Any] | None]:
    """Decide what happens to the blocks of a request that just finished.

    Args:
        request: The finished request; its ``kv_transfer_params`` drive the
            decision.
        block_ids: The request's block ids, one group per entry.

    Returns:
        A ``(delay_free_blocks, kv_transfer_params)`` pair. The flag is True
        when the blocks must be kept alive because they will be sent
        asynchronously and freed later; the dict, when present, carries the
        parameters describing where the blocks can be read from.
    """
    from vllm.v1.request import RequestStatus

    params = request.kv_transfer_params
    req_id = request.request_id
    logger.debug(
        "NIXLConnector request_finished(%s), request_status=%s, "
        "kv_transfer_params=%s",
        req_id,
        request.status,
        params,
    )

    # No transfer params: this request does not take part in KV transfer.
    if not params:
        return False, None

    if params.get("do_remote_prefill"):
        # The flag is only still set if update_state_after_alloc never ran,
        # i.e. the request was aborted before being scheduled. Queue it with
        # an empty block list so the worker side still notifies the prefill
        # instance, which can then free the blocks it is holding for us.
        self._reqs_need_recv[req_id] = (request, [])
        params["do_remote_prefill"] = False
        return False, None

    if not params.get("do_remote_decode"):
        return False, None

    if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
        # Covers a P/D prefill request whose blocks are freed immediately
        # (e.g. an abort): stop tracking it.
        self._reqs_not_processed.add(req_id)
        # An aborted partial prefill may still be waiting to be saved.
        self._reqs_need_save.pop(req_id, None)
        return False, None

    # TODO: check whether block_ids can actually ever be empty. If not, the
    # conditional below could be removed.
    delay_free_blocks = any(len(group) > 0 for group in block_ids)

    if delay_free_blocks:
        # Prefill request on remote. It will be read from D upon completion.
        timeout = envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
        logger.debug(
            "NIXLConnector request_finished(%s) waiting for %d seconds "
            "for remote decode to fetch blocks",
            req_id,
            timeout,
        )
        self._reqs_need_send[req_id] = (
            time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
        )
        # NOTE: HMA "marks" empty/null blocks in groups (e.g. SWA ones) with
        # 0s, trimming down after allocating for the whole sequence length.
        # Those empty blocks always sit at the start of the list, so "unpad"
        # here to advertise only the blocks that are actually to be read.
        block_ids = self.get_sw_clipped_blocks(block_ids)

    remote_params = {
        "do_remote_prefill": True,
        "do_remote_decode": False,
        "remote_block_ids": block_ids,
        "remote_engine_id": self.engine_id,
        "remote_request_id": request.request_id,
        "remote_host": self.side_channel_host,
        "remote_port": self.side_channel_port,
        "tp_size": self.vllm_config.parallel_config.tensor_parallel_size,
    }
    return delay_free_blocks, remote_params
class NixlConnectorWorker: class NixlConnectorWorker:
"""Implementation of Worker side methods""" """Implementation of Worker side methods"""
def __init__( def __init__(
self, vllm_config: VllmConfig, engine_id: str, kv_cache_config: "KVCacheConfig" self,
vllm_config: "VllmConfig",
engine_id: str,
kv_cache_config: "KVCacheConfig",
): ):
if NixlWrapper is None: if NixlWrapper is None:
logger.error("NIXL is not available") logger.error("NIXL is not available")
...@@ -3200,9 +2274,9 @@ class NixlConnectorWorker: ...@@ -3200,9 +2274,9 @@ class NixlConnectorWorker:
the their size differs. the their size differs.
Reference diagram: Reference diagram:
KVCacheTensor (Shared) KVCacheTensor (Shared)
/ \ / \\
/ \ / \\
/ \ / \\
Attention (FlashInfer) View Mamba View Attention (FlashInfer) View Mamba View
| | | |
| | | |
...@@ -3283,266 +2357,3 @@ class NixlConnectorWorker: ...@@ -3283,266 +2357,3 @@ class NixlConnectorWorker:
for desc in self._registered_descs: for desc in self._registered_descs:
self.nixl_wrapper.deregister_memory(desc) self.nixl_wrapper.deregister_memory(desc)
self._registered_descs.clear() self._registered_descs.clear()
@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
    """Yield a ZMQ socket backed by a context that lives only for the block.

    Args:
        socket_type: Either ``zmq.ROUTER`` or ``zmq.REQ``.
        addr: Address passed to ``make_zmq_socket`` as ``path``.

    Yields:
        The socket returned by ``make_zmq_socket``; binding is requested
        only for ``zmq.ROUTER`` sockets.

    Raises:
        ValueError: If ``socket_type`` is neither ROUTER nor REQ.
    """
    supported_types = (zmq.ROUTER, zmq.REQ)
    if socket_type not in supported_types:
        raise ValueError(f"Unexpected socket type: {socket_type}")

    context: zmq.Context | None = None
    try:
        context = zmq.Context()  # type: ignore[attr-defined]
        yield make_zmq_socket(
            ctx=context,
            path=addr,
            socket_type=socket_type,
            bind=socket_type == zmq.ROUTER,
        )
    finally:
        # Tear the context down on every exit path without lingering.
        if context is not None:
            context.destroy(linger=0)
@dataclass
class NixlKVConnectorStats(KVConnectorStats):
    """Accumulator for NIXL transfer telemetry and failure events.

    All samples are stored in ``self.data`` as plain lists keyed by metric
    name, which keeps the container serializable.
    """

    def __post_init__(self):
        # A container created without data starts from the empty layout.
        if self.data:
            return
        self.reset()

    def reset(self):
        """Replace ``self.data`` with a fresh, empty list per metric."""
        metric_names = (
            "transfer_duration",
            "post_duration",
            "bytes_transferred",
            "num_descriptors",
            "num_failed_transfers",
            "num_failed_notifications",
            "num_kv_expired_reqs",
        )
        # Must be serializable
        self.data: dict[str, list[float | int]] = {
            metric: [] for metric in metric_names
        }

    def record_transfer(self, res: nixlXferTelemetry):
        """Record the telemetry of one successful transfer."""
        samples = self.data
        # Durations are divided by 1e6 (us -> s) to keep time units
        # consistent with the rest of the code.
        samples["transfer_duration"].append(res.xferDuration / 1e6)
        samples["post_duration"].append(res.postDuration / 1e6)
        samples["bytes_transferred"].append(res.totalBytes)
        samples["num_descriptors"].append(res.descCount)

    def record_failed_transfer(self):
        """Record a failed NIXL transfer operation."""
        self.data["num_failed_transfers"].append(1)

    def record_failed_notification(self):
        """Record a failed NIXL notification (send_notif)."""
        self.data["num_failed_notifications"].append(1)

    def record_kv_expired_req(self):
        """Record a request that had its KV blocks expire."""
        self.data["num_kv_expired_reqs"].append(1)

    def clone_and_reset(self) -> "NixlKVConnectorStats":
        """Return a snapshot of the current stats and start a new interval."""
        # A shallow copy is enough: reset() rebinds self.data to a new dict,
        # so the snapshot keeps the lists collected so far.
        snapshot = copy.copy(self)
        self.reset()
        return snapshot

    def is_empty(self) -> bool:
        """Return True when neither successes nor failures were recorded."""
        if self.num_successful_transfers != 0:
            return False
        # Do not discard updates that consist entirely of failure events.
        failure_keys = (
            "num_failed_transfers",
            "num_failed_notifications",
            "num_kv_expired_reqs",
        )
        return all(len(self.data[key]) == 0 for key in failure_keys)

    def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
        """Append the samples of ``other`` to this container; return self."""
        if other.is_empty():
            return self
        for metric, values in other.data.items():
            target = self.data[metric]
            assert isinstance(target, list)
            target.extend(values)
        return self

    def reduce(self) -> dict[str, int | float]:
        """Summarize the interval into compact stats suitable for CLI logging.

        Only successful transfers are summarized; when there are none, every
        entry is reported as 0 (failures are reported through Prometheus
        instead).
        """
        if self.num_successful_transfers == 0:
            return dict.fromkeys(
                (
                    "Num successful transfers",
                    "Avg xfer time (ms)",
                    "P90 xfer time (ms)",
                    "Avg post time (ms)",
                    "P90 post time (ms)",
                    "Avg MB per transfer",
                    "Throughput (MB/s)",
                    "Avg number of descriptors",
                ),
                0,
            )

        def to_ms(seconds):
            # Seconds -> milliseconds, rounded for display.
            return round(seconds * 1e3, 3)

        xfer_time = np.asarray(self.data["transfer_duration"])
        post_time = np.asarray(self.data["post_duration"])
        # Bytes are converted to MB (2**20 bytes) for CLI logging.
        mb = np.asarray(self.data["bytes_transferred"]) / 2**20
        descs = np.asarray(self.data["num_descriptors"], dtype=np.uint32)

        n = len(descs)
        assert n == self.num_successful_transfers

        total_mb = mb.sum()
        total_time_seconds = xfer_time.sum()
        avg_mb = total_mb / n
        throughput_mb_s = total_mb / total_time_seconds

        summary = {
            "Num successful transfers": n,
            "Avg xfer time (ms)": to_ms(xfer_time.mean()),
            "P90 xfer time (ms)": to_ms(np.percentile(xfer_time, 90).item()),
            "Avg post time (ms)": to_ms(post_time.mean()),
            "P90 post time (ms)": to_ms(np.percentile(post_time, 90).item()),
            "Avg MB per transfer": round(avg_mb, 3),
            "Throughput (MB/s)": round(throughput_mb_s, 3),
            "Avg number of descriptors": round(descs.mean(), 1),
        }
        return summary

    @property
    def num_successful_transfers(self) -> int:
        """Number of transfers recorded via ``record_transfer``."""
        return len(self.data["transfer_duration"])
class NixlPromMetrics(KVConnectorPromMetrics):
    """Prometheus histograms and counters for NIXL KV cache transfers."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        """Create every NIXL metric and expand it per engine.

        Each metric is built once with ``labelnames`` and then passed to
        ``create_metric_per_engine`` together with
        ``self.per_engine_labelvalues``; the result is what ``observe``
        later indexes by engine index.
        """
        super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)

        def per_engine(metric):
            # Expand a freshly created metric across the known engines.
            return create_metric_per_engine(metric, self.per_engine_labelvalues)

        # Duration buckets, in seconds.
        time_buckets = [
            0.001,
            0.005,
            0.01,
            0.025,
            0.05,
            0.075,
            0.1,
            0.2,
            0.3,
            0.5,
            0.75,
            1.0,
            5.0,
        ]
        # The transfer-time histogram leaves out the smallest (1 ms) bucket.
        self.nixl_histogram_xfer_time = per_engine(
            self._histogram_cls(
                name="vllm:nixl_xfer_time_seconds",
                documentation=(
                    "Histogram of transfer duration for NIXL KV Cache transfers."
                ),
                buckets=time_buckets[1:],
                labelnames=labelnames,
            )
        )
        self.nixl_histogram_post_time = per_engine(
            self._histogram_cls(
                name="vllm:nixl_post_time_seconds",
                documentation=(
                    "Histogram of transfer post time for NIXL KV Cache transfers."
                ),
                buckets=time_buckets,
                labelnames=labelnames,
            )
        )

        # Powers of two from 2**11 (2 KiB) up to 2**33 (8 GiB), growing by a
        # factor of four from one bucket to the next.
        byte_buckets = [2 ** (10 + i) for i in range(1, 25, 2)]
        self.nixl_histogram_bytes_transferred = per_engine(
            self._histogram_cls(
                name="vllm:nixl_bytes_transferred",
                documentation=(
                    "Histogram of bytes transferred per NIXL KV Cache transfers."
                ),
                buckets=byte_buckets,
                labelnames=labelnames,
            )
        )

        descriptor_buckets = [
            10,
            20,
            30,
            50,
            75,
            100,
            200,
            400,
            1000,
            2000,
            4000,
            10000,
            20000,
            50000,
        ]
        self.nixl_histogram_num_descriptors = per_engine(
            self._histogram_cls(
                name="vllm:nixl_num_descriptors",
                documentation=(
                    "Histogram of number of descriptors per NIXL KV Cache transfers."
                ),
                buckets=descriptor_buckets,
                labelnames=labelnames,
            )
        )

        # Failure and expiry events are plain counters.
        self.counter_nixl_num_failed_transfers = per_engine(
            self._counter_cls(
                name="vllm:nixl_num_failed_transfers",
                documentation="Number of failed NIXL KV Cache transfers.",
                labelnames=labelnames,
            )
        )
        self.counter_nixl_num_failed_notifications = per_engine(
            self._counter_cls(
                name="vllm:nixl_num_failed_notifications",
                documentation="Number of failed NIXL KV Cache notifications.",
                labelnames=labelnames,
            )
        )
        self.counter_nixl_num_kv_expired_reqs = per_engine(
            self._counter_cls(
                name="vllm:nixl_num_kv_expired_reqs",
                documentation=(
                    "Number of requests that had their KV expire. "
                    "NOTE: This metric is tracked on the P instance."
                ),
                labelnames=labelnames,
            )
        )

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        """Feed one batch of collected samples into the engine's metrics.

        Args:
            transfer_stats_data: Mapping from metric key to a list of
                samples; every key read below must be present.
            engine_idx: Index of the engine whose metrics are updated.
        """
        histogram_sources = (
            (self.nixl_histogram_xfer_time, "transfer_duration"),
            (self.nixl_histogram_post_time, "post_duration"),
            (self.nixl_histogram_bytes_transferred, "bytes_transferred"),
            (self.nixl_histogram_num_descriptors, "num_descriptors"),
        )
        for histogram, data_key in histogram_sources:
            for sample in transfer_stats_data[data_key]:
                histogram[engine_idx].observe(sample)

        counter_sources = (
            (self.counter_nixl_num_failed_transfers, "num_failed_transfers"),
            (
                self.counter_nixl_num_failed_notifications,
                "num_failed_notifications",
            ),
            (self.counter_nixl_num_kv_expired_reqs, "num_kv_expired_reqs"),
        )
        for counter, data_key in counter_sources:
            # Each recorded entry is the amount to increment by.
            for amount in transfer_stats_data[data_key]:
                counter[engine_idx].inc(amount)
...@@ -114,7 +114,7 @@ def derive_mamba_conv_split( ...@@ -114,7 +114,7 @@ def derive_mamba_conv_split(
assert len(conv_shape) == 2, f"Expected 2D conv state shape, got {conv_shape}" assert len(conv_shape) == 2, f"Expected 2D conv state shape, got {conv_shape}"
# NOTE (ZhanqiuHu): 3-read requires DS layout, which is already asserted # NOTE (ZhanqiuHu): 3-read requires DS layout, which is already asserted
# in nixl_connector __init__. Use it directly instead of heuristic detection. # in nixl worker __init__. Use it directly instead of heuristic detection.
assert is_conv_state_dim_first(), "3-read requires DS conv state layout" assert is_conv_state_dim_first(), "3-read requires DS conv state layout"
local_conv_dim = conv_shape[0] # DS: (conv_dim_local, conv_rows) local_conv_dim = conv_shape[0] # DS: (conv_dim_local, conv_rows)
conv_rows = conv_shape[1] conv_rows = conv_shape[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment