Unverified commit 3244a2eb, authored by Nicolò Lucchesi and committed by GitHub
Browse files

[KVConnector][NIXL] Organize NIXL connector into its own directory (#39354)



The number of features supported by the connector has grown substantially,
and the `nixl_connector.py` file has accumulated a lot of code. This change creates a
separate directory and isolates the connector/scheduler code in the hope of improving
clarity and maintainability.

Further refactoring of components, aimed at improving clarity and simplifying the
code, will follow soon.
Signed-off-by: NickLucche <nlucches@redhat.com>
parent 72ff142c
...@@ -72,4 +72,4 @@ For the PD disaggregation part, the Prefill instance receives cache exactly the ...@@ -72,4 +72,4 @@ For the PD disaggregation part, the Prefill instance receives cache exactly the
`docs/features/disagg_prefill.md` shows the brief idea about the disaggregated prefill (v0) `docs/features/disagg_prefill.md` shows the brief idea about the disaggregated prefill (v0)
We create the example setup with the **NixlConnector** from `vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py` and referred to the `tests/v1/kv_connector/nixl_integration/toy_proxy_server.py` to facilitate the kv transfer between P and D; We create the example setup with the **NixlConnector** from `vllm/distributed/kv_transfer/kv_connector/v1/nixl/` and referred to the `tests/v1/kv_connector/nixl_integration/toy_proxy_server.py` to facilitate the kv transfer between P and D;
...@@ -19,7 +19,7 @@ METRIC_SOURCE_FILES = [ ...@@ -19,7 +19,7 @@ METRIC_SOURCE_FILES = [
"output": "spec_decode.inc.md", "output": "spec_decode.inc.md",
}, },
{ {
"path": "vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py", "path": "vllm/distributed/kv_transfer/kv_connector/v1/nixl/stats.py",
"output": "nixl_connector.inc.md", "output": "nixl_connector.inc.md",
}, },
{"path": "vllm/v1/metrics/perf.py", "output": "perf.inc.md"}, {"path": "vllm/v1/metrics/perf.py", "output": "perf.inc.md"},
......
...@@ -21,7 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( ...@@ -21,7 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats, MultiKVConnectorStats,
MultiKVConnectorWorkerMetadata, MultiKVConnectorWorkerMetadata,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl import (
NixlKVConnectorStats, NixlKVConnectorStats,
) )
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
......
...@@ -24,13 +24,13 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( ...@@ -24,13 +24,13 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
TpKVTopology, TpKVTopology,
get_current_attn_backend, get_current_attn_backend,
) )
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector from vllm.distributed.kv_transfer.kv_connector.v1 import nixl
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats, MultiKVConnectorStats,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl import (
KVConnectorRole,
NixlAgentMetadata, NixlAgentMetadata,
NixlConnector, NixlConnector,
NixlConnectorMetadata, NixlConnectorMetadata,
...@@ -38,6 +38,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( ...@@ -38,6 +38,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorWorker, NixlConnectorWorker,
NixlHandshakePayload, NixlHandshakePayload,
NixlKVConnectorStats, NixlKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
compute_nixl_compatibility_hash, compute_nixl_compatibility_hash,
) )
from vllm.distributed.kv_transfer.kv_transfer_state import ( from vllm.distributed.kv_transfer.kv_transfer_state import (
...@@ -320,7 +322,7 @@ def test_prompt_less_than_block_size(): ...@@ -320,7 +322,7 @@ def test_prompt_less_than_block_size():
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_kv_transfer_handshake(dist_init): def test_kv_transfer_handshake(dist_init):
...@@ -534,7 +536,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -534,7 +536,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
class TestNixlHandshake: class TestNixlHandshake:
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_multi_xfer_one_engine( def test_multi_xfer_one_engine(
...@@ -621,7 +623,7 @@ class TestNixlHandshake: ...@@ -621,7 +623,7 @@ class TestNixlHandshake:
connector.clear_connector_metadata() connector.clear_connector_metadata()
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -691,7 +693,7 @@ class TestNixlHandshake: ...@@ -691,7 +693,7 @@ class TestNixlHandshake:
raise TimeoutError("Took too long to complete async handshake.") raise TimeoutError("Took too long to complete async handshake.")
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
@pytest.mark.parametrize("local_tp_size", [1, 2]) @pytest.mark.parametrize("local_tp_size", [1, 2])
...@@ -703,7 +705,7 @@ class TestNixlHandshake: ...@@ -703,7 +705,7 @@ class TestNixlHandshake:
remote configurations. remote configurations.
""" """
monkeypatch.setattr( monkeypatch.setattr(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.get_tensor_model_parallel_world_size",
lambda: local_tp_size, lambda: local_tp_size,
) )
...@@ -760,7 +762,7 @@ class TestNixlHandshake: ...@@ -760,7 +762,7 @@ class TestNixlHandshake:
check_handshake(6) check_handshake(6)
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_prefill_tp_size_greater_than_decode_tp_size_mla( def test_prefill_tp_size_greater_than_decode_tp_size_mla(
...@@ -863,7 +865,7 @@ class TestNixlHandshake: ...@@ -863,7 +865,7 @@ class TestNixlHandshake:
assert req_id not in conn_p1.connector_worker._reqs_to_process assert req_id not in conn_p1.connector_worker._reqs_to_process
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_concurrent_load_kv( def test_concurrent_load_kv(
...@@ -928,7 +930,7 @@ class TestNixlHandshake: ...@@ -928,7 +930,7 @@ class TestNixlHandshake:
raise TimeoutError("Took too long to complete async handshake.") raise TimeoutError("Took too long to complete async handshake.")
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_handshake_fails_on_kv_cache_layout_mismatch( def test_handshake_fails_on_kv_cache_layout_mismatch(
...@@ -943,7 +945,7 @@ class TestNixlHandshake: ...@@ -943,7 +945,7 @@ class TestNixlHandshake:
# Mock TP world size to 2 to force heterogeneous TP when # Mock TP world size to 2 to force heterogeneous TP when
# remote_tp_size=1 # remote_tp_size=1
with patch( with patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501 "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.get_tensor_model_parallel_world_size", # noqa: E501
return_value=2, return_value=2,
): ):
# Initialize connector and worker (with fake NIXL wrapper) # Initialize connector and worker (with fake NIXL wrapper)
...@@ -982,7 +984,7 @@ class TestNixlHandshake: ...@@ -982,7 +984,7 @@ class TestNixlHandshake:
worker.add_remote_agent(meta, remote_tp_size=1) worker.add_remote_agent(meta, remote_tp_size=1)
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
...@@ -997,7 +999,7 @@ class TestNixlHandshake: ...@@ -997,7 +999,7 @@ class TestNixlHandshake:
# Mock TP world size to 2 to force heterogeneous TP when # Mock TP world size to 2 to force heterogeneous TP when
# remote_tp_size=1 # remote_tp_size=1
with patch( with patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501 "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.get_tensor_model_parallel_world_size", # noqa: E501
return_value=2, return_value=2,
): ):
# Initialize connector and worker (with fake NIXL wrapper) # Initialize connector and worker (with fake NIXL wrapper)
...@@ -1042,7 +1044,7 @@ class TestNixlHandshake: ...@@ -1042,7 +1044,7 @@ class TestNixlHandshake:
# we put here is important. First run ray, it will clean up the resources, then # we put here is important. First run ray, it will clean up the resources, then
# the rest of the tests. # the rest of the tests.
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_kv_connector_stats(default_vllm_config, dist_init): def test_kv_connector_stats(default_vllm_config, dist_init):
...@@ -1227,8 +1229,8 @@ def test_multi_kv_connector_stats_aggregation(): ...@@ -1227,8 +1229,8 @@ def test_multi_kv_connector_stats_aggregation():
worker_patterns = [(2, 1), (3, 0), (0, 5)] # (Nixl, Foo) worker_patterns = [(2, 1), (3, 0), (0, 5)] # (Nixl, Foo)
worker_outputs: list[ModelRunnerOutput] = [] worker_outputs: list[ModelRunnerOutput] = []
for i, (nixl, foo) in enumerate(worker_patterns): for i, (nixl_count, foo) in enumerate(worker_patterns):
stats = make_multi_stats(nixl, foo) stats = make_multi_stats(nixl_count, foo)
output = ModelRunnerOutput( output = ModelRunnerOutput(
req_ids=[f"req_{i}"], req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0}, req_id_to_index={f"req_{i}": 0},
...@@ -1256,7 +1258,7 @@ def test_multi_kv_connector_stats_aggregation(): ...@@ -1256,7 +1258,7 @@ def test_multi_kv_connector_stats_aggregation():
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_scheduler_kv_connector_stats_aggregation(): def test_scheduler_kv_connector_stats_aggregation():
...@@ -1324,7 +1326,7 @@ def test_scheduler_kv_connector_stats_aggregation(): ...@@ -1324,7 +1326,7 @@ def test_scheduler_kv_connector_stats_aggregation():
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None]) @pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
...@@ -1513,13 +1515,14 @@ def test_register_kv_caches( ...@@ -1513,13 +1515,14 @@ def test_register_kv_caches(
backend_cls = TritonAttentionBackend backend_cls = TritonAttentionBackend
nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" nixl_worker = "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker"
nixl_connector = "vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector"
with ( with (
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper, patch(f"{nixl_worker}.NixlWrapper") as mock_nixl_wrapper,
patch(f"{nixl_module}.threading.Event"), patch(f"{nixl_worker}.threading.Event"),
patch(f"{nixl_module}.threading.Thread") as mock_thread, patch(f"{nixl_worker}.threading.Thread") as mock_thread,
patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend, patch(f"{nixl_connector}.get_current_attn_backend") as mock_get_attn_backend,
patch(f"{nixl_module}.get_current_attn_backends") as mock_get_attn_backends, patch(f"{nixl_worker}.get_current_attn_backends") as mock_get_attn_backends,
): ):
# Ensure get_attn_backend returns the correct value due to # Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous # _cached_get_attn_backend returning the backend from previous
...@@ -1754,28 +1757,26 @@ def test_kv_buffer_to_nixl_memory_types( ...@@ -1754,28 +1757,26 @@ def test_kv_buffer_to_nixl_memory_types(
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
# Override the default memory types in the config # Override the default memory types in the config
vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import (
_NIXL_SUPPORTED_DEVICE, _NIXL_SUPPORTED_DEVICE,
) )
_NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices()) _NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices())
with ( with (
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper"),
patch( patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper" "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.threading.Event"
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"
), ),
patch( patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread" "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.threading.Thread"
), ),
patch( patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.current_platform",
FakePlatform, FakePlatform,
), ),
patch( patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils._NIXL_SUPPORTED_DEVICE",
_NIXL_SUPPORTED_DEVICE, _NIXL_SUPPORTED_DEVICE,
), ),
): # noqa: E501 ): # noqa: E501
...@@ -1790,7 +1791,7 @@ def test_kv_buffer_to_nixl_memory_types( ...@@ -1790,7 +1791,7 @@ def test_kv_buffer_to_nixl_memory_types(
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_shutdown_cleans_up_resources(default_vllm_config, dist_init): def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
...@@ -1855,7 +1856,7 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init): ...@@ -1855,7 +1856,7 @@ def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_init): def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_init):
...@@ -1975,7 +1976,7 @@ class FailingNixlWrapper(FakeNixlWrapper): ...@@ -1975,7 +1976,7 @@ class FailingNixlWrapper(FakeNixlWrapper):
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FailingNixlWrapper, FailingNixlWrapper,
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -2065,10 +2066,10 @@ def test_transfer_failure_logging( ...@@ -2065,10 +2066,10 @@ def test_transfer_failure_logging(
slot_mapping={}, slot_mapping={},
) )
# Capture logs from the nixl_connector logger specifically # Capture logs from the nixl.worker logger specifically
# vLLM loggers have propagate=False, so we need to capture directly # vLLM loggers have propagate=False, so we need to capture directly
nixl_logger = logging.getLogger( nixl_logger = logging.getLogger(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker"
) )
captured_logs: list[logging.LogRecord] = [] captured_logs: list[logging.LogRecord] = []
...@@ -2130,7 +2131,7 @@ def test_transfer_failure_logging( ...@@ -2130,7 +2131,7 @@ def test_transfer_failure_logging(
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FailingNixlWrapper, FailingNixlWrapper,
) )
def test_handshake_failure_returns_finished(default_vllm_config, dist_init): def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
...@@ -2181,7 +2182,7 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init): ...@@ -2181,7 +2182,7 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FailingNixlWrapper, FailingNixlWrapper,
) )
def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init): def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init):
...@@ -2257,7 +2258,7 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init) ...@@ -2257,7 +2258,7 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
], ],
) )
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_compatibility_hash_validation( def test_compatibility_hash_validation(
...@@ -2328,7 +2329,7 @@ def test_compatibility_hash_validation( ...@@ -2328,7 +2329,7 @@ def test_compatibility_hash_validation(
elif "connector_version" in version_override: elif "connector_version" in version_override:
stack.enter_context( stack.enter_context(
patch.object( patch.object(
nixl_connector, nixl.metadata,
"NIXL_CONNECTOR_VERSION", "NIXL_CONNECTOR_VERSION",
version_override["connector_version"], version_override["connector_version"],
) )
...@@ -2365,7 +2366,7 @@ def test_compatibility_hash_validation( ...@@ -2365,7 +2366,7 @@ def test_compatibility_hash_validation(
# Patch zmq_ctx to return our mock socket # Patch zmq_ctx to return our mock socket
with ( with (
patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"), patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx, patch.object(nixl.worker, "zmq_ctx") as mock_zmq_ctx,
): ):
mock_zmq_ctx.return_value.__enter__.return_value = mock_socket mock_zmq_ctx.return_value.__enter__.return_value = mock_socket
...@@ -2399,7 +2400,7 @@ def test_compatibility_hash_validation( ...@@ -2399,7 +2400,7 @@ def test_compatibility_hash_validation(
], ],
) )
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
) )
def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario): def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario):
...@@ -2464,7 +2465,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) ...@@ -2464,7 +2465,7 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
mock_socket.recv.return_value = msg_bytes mock_socket.recv.return_value = msg_bytes
with ( with (
patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"), patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx, patch.object(nixl.worker, "zmq_ctx") as mock_zmq_ctx,
): ):
mock_zmq_ctx.return_value.__enter__.return_value = mock_socket mock_zmq_ctx.return_value.__enter__.return_value = mock_socket
......
...@@ -31,10 +31,10 @@ from .utils import ( ...@@ -31,10 +31,10 @@ from .utils import (
(False, [0]), (False, [0]),
], ],
) )
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") @patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler.current_platform")
def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes): def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes):
"""Test sw_sizes is correctly computed based on SWA enabled/disabled.""" """Test sw_sizes is correctly computed based on SWA enabled/disabled."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import (
NixlConnectorScheduler, NixlConnectorScheduler,
) )
...@@ -65,7 +65,7 @@ def test_logical_to_kernel_block_ids_with_hma(): ...@@ -65,7 +65,7 @@ def test_logical_to_kernel_block_ids_with_hma():
When HMA is enabled, the logical block size may differ from the kernel When HMA is enabled, the logical block size may differ from the kernel
block size. Each logical block maps to multiple kernel blocks. block size. Each logical block maps to multiple kernel blocks.
""" """
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker, NixlConnectorWorker,
) )
...@@ -169,7 +169,7 @@ def test_nixl_metadata_hma_block_ids_structure(): ...@@ -169,7 +169,7 @@ def test_nixl_metadata_hma_block_ids_structure():
Test that NixlConnectorMetadata correctly stores block IDs for multiple Test that NixlConnectorMetadata correctly stores block IDs for multiple
KV cache groups when HMA is enabled. KV cache groups when HMA is enabled.
""" """
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
NixlConnectorMetadata, NixlConnectorMetadata,
) )
...@@ -211,7 +211,7 @@ def test_nixl_metadata_hma_block_ids_structure(): ...@@ -211,7 +211,7 @@ def test_nixl_metadata_hma_block_ids_structure():
def test_get_block_descs_ids_hybrid_ssm(): def test_get_block_descs_ids_hybrid_ssm():
"""Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM """Test _get_block_descs_ids uses per-group strides for hybrid FA+SSM
when ratio=1 (no kernel block size mismatch).""" when ratio=1 (no kernel block size mismatch)."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker, NixlConnectorWorker,
) )
...@@ -247,7 +247,7 @@ def test_get_block_descs_ids_hybrid_ssm(): ...@@ -247,7 +247,7 @@ def test_get_block_descs_ids_hybrid_ssm():
def test_get_block_descs_ids_kernel_block_mismatch(): def test_get_block_descs_ids_kernel_block_mismatch():
"""Test _get_block_descs_ids uses different strides for FA (kernel blocks) """Test _get_block_descs_ids uses different strides for FA (kernel blocks)
vs SSM (logical blocks) when ratio > 1.""" vs SSM (logical blocks) when ratio > 1."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker, NixlConnectorWorker,
) )
...@@ -284,7 +284,7 @@ def test_get_block_descs_ids_kernel_block_mismatch(): ...@@ -284,7 +284,7 @@ def test_get_block_descs_ids_kernel_block_mismatch():
def test_nixl_metadata_hybrid_ssm_block_ids(): def test_nixl_metadata_hybrid_ssm_block_ids():
"""Test NixlConnectorMetadata correctly stores block IDs for FA + SSM """Test NixlConnectorMetadata correctly stores block IDs for FA + SSM
groups with different block counts (kernel mismatch active).""" groups with different block counts (kernel mismatch active)."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
NixlConnectorMetadata, NixlConnectorMetadata,
) )
...@@ -392,7 +392,7 @@ def test_mamba_n1_p_side_truncation(): ...@@ -392,7 +392,7 @@ def test_mamba_n1_p_side_truncation():
], ],
ids=["fa_swa_mamba", "fa_swa_only", "fa_only"], ids=["fa_swa_mamba", "fa_swa_only", "fa_only"],
) )
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") @patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler.current_platform")
def test_has_mamba_init( def test_has_mamba_init(
mock_platform, mock_platform,
swa_enabled, swa_enabled,
...@@ -401,7 +401,7 @@ def test_has_mamba_init( ...@@ -401,7 +401,7 @@ def test_has_mamba_init(
expected_is_hma, expected_is_hma,
): ):
"""Test _has_mamba / _is_hma_required derived from kv_cache_groups.""" """Test _has_mamba / _is_hma_required derived from kv_cache_groups."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import (
NixlConnectorScheduler, NixlConnectorScheduler,
) )
......
...@@ -587,7 +587,7 @@ def test_cannot_recv(): ...@@ -587,7 +587,7 @@ def test_cannot_recv():
assert_scheduler_empty(scheduler) assert_scheduler_empty(scheduler)
@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") @patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler.current_platform")
def test_p_side_chunked_prefill_mamba(mock_platform): def test_p_side_chunked_prefill_mamba(mock_platform):
"""P-side integration: Mamba N-1 truncation + chunked prefill completes. """P-side integration: Mamba N-1 truncation + chunked prefill completes.
......
...@@ -476,7 +476,7 @@ def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False): ...@@ -476,7 +476,7 @@ def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False):
Only sets the two flags needed by the N-1 prefill logic. Only sets the two flags needed by the N-1 prefill logic.
""" """
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import (
NixlConnectorScheduler, NixlConnectorScheduler,
) )
......
...@@ -178,7 +178,7 @@ KVConnectorFactory.register_connector( ...@@ -178,7 +178,7 @@ KVConnectorFactory.register_connector(
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
"NixlConnector", "NixlConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", "vllm.distributed.kv_transfer.kv_connector.v1.nixl",
"NixlConnector", "NixlConnector",
) )
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""NIXL KV-cache transfer connector (disaggregated prefill / decode)."""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector import (
NixlConnector,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
NixlAgentMetadata,
NixlConnectorMetadata,
NixlHandshakePayload,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import (
NixlConnectorScheduler,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.stats import (
NixlKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)
# Public API of the `nixl` package: each name below is re-exported from the
# connector, metadata, scheduler, stats or worker submodule imported above.
__all__ = [
"NixlAgentMetadata",
"NixlConnector",
"NixlConnectorMetadata",
"NixlConnectorScheduler",
"NixlConnectorWorker",
"NixlHandshakePayload",
"NixlKVConnectorStats",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""NixlConnector – thin facade that delegates to scheduler / worker."""
from typing import TYPE_CHECKING, Any
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
get_current_attn_backend,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
KVConnectorBase_V1,
KVConnectorHandshakeMetadata,
KVConnectorMetadata,
KVConnectorRole,
SupportsHMA,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.metadata import (
NixlConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.scheduler import (
NixlConnectorScheduler,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.stats import (
NixlKVConnectorStats,
NixlPromMetrics,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import MambaSpec
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
class NixlConnector(KVConnectorBase_V1, SupportsHMA):
@property
def prefer_cross_layer_blocks(self) -> bool:
    """Whether the cross-layer KV block layout should be used.

    Returns ``False`` when any KV cache group uses a ``MambaSpec``, when the
    current attention backend is not one of FLASH_ATTN / FLASHINFER /
    TRITON_ATTN, or when the KV cache layout is not "HND". Otherwise the
    result is the ``enable_cross_layers_blocks`` entry of the connector's
    extra config, compared case-insensitively against "true" (a missing
    entry is treated as "False").
    """
    # A generator lets any() stop at the first Mamba group instead of
    # materialising a throwaway list of booleans.
    if any(
        isinstance(group.kv_cache_spec, MambaSpec)
        for group in self.kv_cache_config.kv_cache_groups
    ):
        # Hybrid SSM models do not yet support cross-layer layout
        return False

    backend = get_current_attn_backend(self._vllm_config)
    if backend.get_name() not in (
        "FLASH_ATTN",
        "FLASHINFER",
        "TRITON_ATTN",
    ):
        return False

    # For now there is no benefit in running cross-layer blocks unless the
    # KV cache layout is HND.
    if get_kv_cache_layout() != "HND":
        return False

    extra_config = self.kv_transfer_config.kv_connector_extra_config
    return (
        str(extra_config.get("enable_cross_layers_blocks", "False")).lower()
        == "true"
    )
def __init__(
    self,
    vllm_config: VllmConfig,
    role: KVConnectorRole,
    kv_cache_config: "KVCacheConfig",
):
    """Build the role-specific half of the connector.

    A SCHEDULER-role connector owns a ``NixlConnectorScheduler`` and a
    WORKER-role connector owns a ``NixlConnectorWorker``; in each case the
    other slot is set to ``None``.
    """
    super().__init__(vllm_config, role, kv_cache_config)

    # Read the transfer config once; it must exist and carry an engine id.
    transfer_config = vllm_config.kv_transfer_config
    assert transfer_config is not None
    assert transfer_config.engine_id is not None

    self.kv_cache_config = kv_cache_config
    self.engine_id: EngineId = transfer_config.engine_id
    self.kv_transfer_config = transfer_config

    if role == KVConnectorRole.SCHEDULER:
        scheduler_half = NixlConnectorScheduler(
            vllm_config, self.engine_id, kv_cache_config
        )
        self.connector_scheduler: NixlConnectorScheduler | None = scheduler_half
        self.connector_worker: NixlConnectorWorker | None = None
    elif role == KVConnectorRole.WORKER:
        self.connector_scheduler = None
        worker_half = NixlConnectorWorker(
            vllm_config, self.engine_id, kv_cache_config
        )
        self.connector_worker = worker_half
############################################################
# Class Methods
############################################################
@classmethod
def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
if vllm_config.model_config is None:
logger.warning_once(
"Unable to detect current VLLM config. "
"Fallback to default kv cache layout."
)
return None
use_mla = vllm_config.model_config.use_mla
if use_mla:
# return None when we have mla
# as the layout should not matter in that case,
# which fallback to the default behavior.
return None
logger.info_once(
"NixlConnector setting KV cache layout to HND for better xfer performance."
)
return "HND"
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens
)
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens
)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, (block_ids,))
def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (dict): the handshake metadata to set.
"""
assert self.connector_scheduler is not None
self.connector_scheduler.set_xfer_handshake_metadata(metadata)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_caches(kv_cache)
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
assert self.connector_worker is not None
self.connector_worker.set_host_xfer_buffer_ops(copy_operation)
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
return self.connector_worker.get_finished()
def get_block_ids_with_load_errors(self) -> set[int]:
"""Get block IDs that failed to load via NIXL."""
assert self.connector_worker is not None
return self.connector_worker.get_block_ids_with_load_errors()
def get_kv_connector_stats(self) -> KVConnectorStats | None:
if self.connector_worker is None:
return None
return self.connector_worker.get_kv_connector_stats()
@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> KVConnectorStats | None:
return (
NixlKVConnectorStats(data=data)
if data is not None
else NixlKVConnectorStats()
)
@classmethod
def build_prom_metrics(
cls,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
) -> KVConnectorPromMetrics:
return NixlPromMetrics(
vllm_config, metric_types, labelnames, per_engine_labelvalues
)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
"""NixlConnector does not do layerwise saving."""
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""NixlConnector does not save explicitly."""
pass
def wait_for_save(self):
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
if self.connector_worker.use_host_buffer and self.connector_worker.copy_blocks:
self.connector_worker.save_kv_to_host(self._connector_metadata)
def shutdown(self):
if self.connector_worker is not None:
self.connector_worker.shutdown()
if self.connector_scheduler is not None:
self.connector_scheduler.shutdown()
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
Get the KVConnector handshake metadata for this connector.
This metadata is used for out-of-band connector handshake
between P/D workers.
Returns:
KVConnectorHandshakeMetadata: the handshake metadata.
None if no handshake metadata is available.
"""
assert self.connector_worker is not None
return self.connector_worker.xfer_handshake_metadata
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Metadata dataclasses and helpers for the NIXL connector."""
from dataclasses import dataclass
from typing import Any
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import BlockIds
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata,
KVConnectorMetadata,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
# Type alias: a transfer handle is represented as a plain int.
TransferHandle = int
# Type alias: request ids are plain strings (used as keys/elements in
# NixlConnectorMetadata's per-request containers).
ReqId = str
# Byte-string message constant; presumably the request sent to ask a peer for
# its handshake metadata -- TODO confirm against the code that sends it.
GET_META_MSG = b"get_meta_msg"
#
# NIXL Connector Version
#
# Increment this version whenever there is an incompatible change to:
# - NixlAgentMetadata schema
# - kv_transfer_params schema or semantics
# - NIXL transfer protocol or wire format
# - KV cache memory layout or block organization
# - Any other change that breaks P/D interoperability
#
# Version History:
# 1: Initial version with compatibility checking
# 2: Add remote_request_id to kv_transfer_params
#
# This value is one of the factors fed into compute_nixl_compatibility_hash.
NIXL_CONNECTOR_VERSION: int = 2
@dataclass
class NixlAgentMetadata:
    """Record describing one NIXL agent and its KV cache registration.

    An encoded instance is what ``NixlHandshakePayload.agent_metadata_bytes``
    carries. Any incompatible change to this schema requires bumping
    ``NIXL_CONNECTOR_VERSION``. Field order is part of the positional
    constructor interface and must not be rearranged.
    """

    # Identifier of the engine that owns the agent.
    engine_id: str
    # Opaque agent metadata blob.
    agent_metadata: bytes
    # Base addresses of the KV caches.
    # presumably one entry per registered cache region -- TODO confirm
    kv_caches_base_addr: list[int]
    device_id: int
    num_blocks: int
    # presumably aligned index-wise with kv_caches_base_addr -- TODO confirm
    block_lens: list[int]
    # KV cache layout name (elsewhere in this package compared against "HND").
    kv_cache_layout: str
    block_size: int
    # NOTE(review): meaning of the two SSM sizes is not visible here.
    ssm_sizes: tuple[int, int]
    attn_backend_name: str
@dataclass
class NixlHandshakePayload(KVConnectorHandshakeMetadata):
    """
    Envelope for the NIXL handshake as it is sent over the wire.

    The envelope makes a two-phase decode possible, so that an incompatible
    peer is rejected gracefully:

    1. Decode the NixlHandshakePayload itself to read compatibility_hash.
    2. Compute the local hash and compare the two.
    3. Decode agent_metadata_bytes only when the hashes match.

    Doing it this way avoids decoder errors when the NixlAgentMetadata schema
    differs between peers, so the failure can be reported with a clear error
    message instead.
    """

    # Hash string to compare against the locally computed compatibility hash.
    compatibility_hash: str
    # NixlAgentMetadata encoded
    agent_metadata_bytes: bytes
def compute_nixl_compatibility_hash(
    vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool
) -> str:
    """
    Compute the compatibility hash used for NIXL KV transfer.

    Only factors that decide whether two NIXL instances can exchange KV cache
    data are hashed:

    - vLLM version and NIXL connector version
    - Model identity and shape (model name, dtype, total KV heads, head size,
      total hidden layers)
    - Attention backend name and KV cache dtype
    - Whether cross-layer blocks are used
    - Whether the hybrid KV cache manager is enabled

    Note: Factors like tensor_parallel_size, block_size, and kv_cache_layout
    are validated at runtime in _validate_remote_agent_handshake and are not
    included in this hash to support heterogeneous deployments.

    Note - the set of factors are likely to evolve significantly over
    time to be more or less permissive.

    Args:
        vllm_config: Source of the model, cache and scheduler settings.
        attn_backend_name: Name of the attention backend in use.
        cross_layers_blocks: Whether the cross-layer block layout is in use.

    Returns:
        The digest string returned by ``hash_factors`` (documented in this
        module as a SHA-256 hex digest).
    """
    # Imported here rather than at module level, as in the original code.
    from vllm import __version__ as vllm_version
    from vllm.config.utils import hash_factors

    model_cfg = vllm_config.model_config
    cache_cfg = vllm_config.cache_config
    hma_enabled = not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager

    model_name = model_cfg.model
    model_dtype = str(model_cfg.dtype)
    total_kv_heads = model_cfg.get_total_num_kv_heads()
    head_size = model_cfg.get_head_size()
    total_layers = model_cfg.get_total_num_hidden_layers()
    kv_cache_dtype = str(cache_cfg.cache_dtype)

    # Key names and insertion order are kept exactly as before so the
    # resulting hash does not change.
    factors = {
        # Version compatibility
        "vllm_version": vllm_version,
        "nixl_connector_version": NIXL_CONNECTOR_VERSION,
        # Model architecture - affects KV cache shape
        "model": model_name,
        "dtype": model_dtype,
        "num_kv_heads": total_kv_heads,
        "head_size": head_size,
        "num_hidden_layers": total_layers,
        # Attention backend and KV cache dtype affect memory layout
        "attn_backend_name": attn_backend_name,
        "cache_dtype": kv_cache_dtype,
        "cross_layers_blocks": cross_layers_blocks,
        "is_hma_enabled": hma_enabled,
    }
    digest = hash_factors(factors)

    logger.debug(
        "NIXL compatibility hash: %s (model=%s, dtype=%s, num_kv_heads=%d, "
        "cache_dtype=%s, attn_backend=%s)",
        digest,
        model_name,
        model_dtype,
        total_kv_heads,
        kv_cache_dtype,
        attn_backend_name,
    )
    return digest
@dataclass
class RemoteMeta:
    """Where a request's KV blocks live on the remote side.

    NixlConnectorMetadata.add_new_req_to_recv fills every field from the
    ``remote_*`` entries of ``kv_transfer_params``.
    """

    # From kv_transfer_params["remote_block_ids"].
    block_ids: BlockIds
    # From kv_transfer_params["remote_host"].
    host: str
    # From kv_transfer_params["remote_port"].
    port: int
    # From kv_transfer_params["remote_engine_id"].
    engine_id: str
    # From kv_transfer_params["remote_request_id"].
    request_id: str
@dataclass
class ReqMeta:
    """Per-request transfer metadata stored in NixlConnectorMetadata."""

    # Block ids on the local side.
    local_block_ids: BlockIds
    # To be used when logical block size does not match the kernel block size
    local_physical_block_ids: BlockIds
    # Taken from kv_transfer_params["tp_size"], defaulting to 1 when absent.
    tp_size: int
    # Remote location of the blocks; only set for requests to receive.
    remote: RemoteMeta | None = None
class NixlConnectorMetadata(KVConnectorMetadata):
    """Per-request bookkeeping carried by the NIXL connector metadata.

    Holds the requests to receive and to save (as ``ReqMeta`` keyed by request
    id), a request-id to float mapping for requests to send, and two sets of
    request ids.
    """

    def __init__(self):
        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
        self.reqs_to_save: dict[ReqId, ReqMeta] = {}
        # NOTE(review): the meaning of the float value is not visible in this
        # module -- presumably a timestamp/deadline; confirm with the producer.
        self.reqs_to_send: dict[ReqId, float] = {}
        self.reqs_in_batch: set[ReqId] = set()
        self.reqs_not_processed: set[ReqId] = set()

    def _add_new_req(
        self,
        local_block_ids: BlockIds,
        kv_transfer_params: dict[str, Any],
    ) -> ReqMeta:
        """Build a ``ReqMeta`` without remote info.

        The same block ids are used for both the logical and the physical
        field; ``tp_size`` falls back to 1 when the params do not carry it.
        """
        # P workers don't need to receive tp_size from proxy here.
        tp_size = kv_transfer_params.get("tp_size", 1)
        return ReqMeta(
            local_block_ids=local_block_ids,
            local_physical_block_ids=local_block_ids,
            tp_size=tp_size,
        )

    def add_new_req_to_save(
        self,
        request_id: ReqId,
        local_block_ids: BlockIds,
        kv_transfer_params: dict[str, Any],
    ):
        """Register ``request_id`` among the requests to save."""
        meta = self._add_new_req(local_block_ids, kv_transfer_params)
        self.reqs_to_save[request_id] = meta

    def add_new_req_to_recv(
        self,
        request_id: ReqId,
        local_block_ids: BlockIds,
        kv_transfer_params: dict[str, Any],
    ):
        """Register ``request_id`` among the requests to receive.

        Raises:
            KeyError: if any of the ``remote_*`` entries is missing from
                ``kv_transfer_params``.
        """
        meta = self._add_new_req(local_block_ids, kv_transfer_params)
        params = kv_transfer_params
        meta.remote = RemoteMeta(
            block_ids=params["remote_block_ids"],
            engine_id=params["remote_engine_id"],
            request_id=params["remote_request_id"],
            host=params["remote_host"],
            port=params["remote_port"],
        )
        self.reqs_to_recv[request_id] = meta
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Stats and Prometheus metrics for the NIXL connector."""
import copy
from dataclasses import dataclass
from typing import Any
import numpy as np
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.utils import (
nixlXferTelemetry,
)
from vllm.v1.metrics.utils import create_metric_per_engine
@dataclass
class NixlKVConnectorStats(KVConnectorStats):
    """Container for NIXL transfer performance metrics.

    ``self.data`` maps each metric name to a list of samples. Successful
    transfers append one sample to each of the four transfer metrics; failure
    and expiry events append a ``1`` to their own list.
    """

    def __post_init__(self):
        if not self.data:
            # Empty container init, no data is passed in.
            self.reset()

    def reset(self):
        """Replace ``self.data`` with a fresh dict of empty sample lists."""
        # Must be serializable
        self.data: dict[str, list[float | int]] = {
            "transfer_duration": [],
            "post_duration": [],
            "bytes_transferred": [],
            "num_descriptors": [],
            "num_failed_transfers": [],
            "num_failed_notifications": [],
            "num_kv_expired_reqs": [],
        }

    def record_transfer(self, res: nixlXferTelemetry):
        """Record the telemetry of one successful transfer."""
        # Keep metrics units consistent with rest of the code: time us->s
        self.data["transfer_duration"].append(res.xferDuration / 1e6)
        self.data["post_duration"].append(res.postDuration / 1e6)
        self.data["bytes_transferred"].append(res.totalBytes)
        self.data["num_descriptors"].append(res.descCount)

    def record_failed_transfer(self):
        """Record a failed NIXL transfer operation."""
        self.data["num_failed_transfers"].append(1)

    def record_failed_notification(self):
        """Record a failed NIXL notification (send_notif)."""
        self.data["num_failed_notifications"].append(1)

    def record_kv_expired_req(self):
        """Record a request that had its KV blocks expire."""
        self.data["num_kv_expired_reqs"].append(1)

    def clone_and_reset(self) -> "NixlKVConnectorStats":
        """Return a snapshot of the current stats and start a new interval.

        The shallow copy keeps a reference to the current ``data`` dict;
        ``reset()`` then rebinds ``self.data`` to a new dict, so the snapshot
        is not affected by later recordings.
        """
        old = copy.copy(self)
        self.reset()
        return old

    def is_empty(self) -> bool:
        """Return True when no success, failure or expiry was recorded."""
        # Do not discard metrics update that are entirely failures related.
        return (
            self.num_successful_transfers == 0
            and len(self.data["num_failed_transfers"]) == 0
            and len(self.data["num_failed_notifications"]) == 0
            and len(self.data["num_kv_expired_reqs"]) == 0
        )

    def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
        """Extend this object's sample lists with those of ``other``.

        Returns:
            ``self``, to allow chaining.
        """
        if not other.is_empty():
            for k, v in other.data.items():
                accumulator = self.data[k]
                assert isinstance(accumulator, list)
                accumulator.extend(v)
        return self

    def reduce(self) -> dict[str, int | float]:
        """Compute compact representative stats suitable for CLI logging.

        Returns:
            A dict of summary values over the successful transfers; all zeros
            when there were none.
        """
        if self.num_successful_transfers == 0:
            # CLI logging only reports successful transfers stats. If all requests in
            # the interval were unsuccessful, Prom will report failures stats instead.
            return {
                "Num successful transfers": 0,
                "Avg xfer time (ms)": 0,
                "P90 xfer time (ms)": 0,
                "Avg post time (ms)": 0,
                "P90 post time (ms)": 0,
                "Avg MB per transfer": 0,
                "Throughput (MB/s)": 0,
                "Avg number of descriptors": 0,
            }

        xfer_time = np.asarray(self.data["transfer_duration"])
        post_time = np.asarray(self.data["post_duration"])
        # Convert to MB for CLI logging.
        mb = np.asarray(self.data["bytes_transferred"]) / 2**20
        descs = np.asarray(self.data["num_descriptors"], dtype=np.uint32)
        n = len(descs)
        assert n == self.num_successful_transfers

        total_mb = mb.sum()
        avg_mb = total_mb / n

        total_time_seconds = xfer_time.sum()
        # If every recorded duration is zero the division would produce
        # inf/nan together with a NumPy RuntimeWarning; report 0 instead, in
        # line with the "no successful transfers" branch above.
        if total_time_seconds > 0:
            throughput_mb_s = total_mb / total_time_seconds
        else:
            throughput_mb_s = 0.0

        return {
            "Num successful transfers": n,
            "Avg xfer time (ms)": round(xfer_time.mean() * 1e3, 3),
            "P90 xfer time (ms)": round(np.percentile(xfer_time, 90).item() * 1e3, 3),
            "Avg post time (ms)": round(post_time.mean() * 1e3, 3),
            "P90 post time (ms)": round(np.percentile(post_time, 90).item() * 1e3, 3),
            "Avg MB per transfer": round(avg_mb, 3),
            "Throughput (MB/s)": round(throughput_mb_s, 3),
            "Avg number of descriptors": round(descs.mean(), 1),
        }

    @property
    def num_successful_transfers(self) -> int:
        """Number of recorded successful transfers."""
        return len(self.data["transfer_duration"])
class NixlPromMetrics(KVConnectorPromMetrics):
    """Prometheus histograms and counters for NIXL KV cache transfers.

    Each metric is created once and then expanded per engine through
    ``create_metric_per_engine``; ``observe`` feeds a stats ``data`` dict
    into the metrics of one engine.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)

        def per_engine(metric):
            # Expand one metric into its per-engine instances.
            return create_metric_per_engine(metric, self.per_engine_labelvalues)

        # Latency buckets shared by the two timing histograms.
        time_buckets = [
            0.001,
            0.005,
            0.01,
            0.025,
            0.05,
            0.075,
            0.1,
            0.2,
            0.3,
            0.5,
            0.75,
            1.0,
            5.0,
        ]
        # The transfer-time histogram leaves out the smallest (0.001) bucket.
        self.nixl_histogram_xfer_time = per_engine(
            self._histogram_cls(
                name="vllm:nixl_xfer_time_seconds",
                documentation="Histogram of transfer duration for NIXL KV Cache transfers.",
                buckets=time_buckets[1:],
                labelnames=labelnames,
            )
        )
        self.nixl_histogram_post_time = per_engine(
            self._histogram_cls(
                name="vllm:nixl_post_time_seconds",
                documentation="Histogram of transfer post time for NIXL KV"
                " Cache transfers.",
                buckets=time_buckets,
                labelnames=labelnames,
            )
        )

        # Powers of two from 2**11 (2 KiB) to 2**33 (8 GiB), each bucket 4x
        # the previous one.
        size_buckets = [2 ** (10 + i) for i in range(1, 25, 2)]
        self.nixl_histogram_bytes_transferred = per_engine(
            self._histogram_cls(
                name="vllm:nixl_bytes_transferred",
                documentation="Histogram of bytes transferred per NIXL KV Cache transfers.",
                buckets=size_buckets,
                labelnames=labelnames,
            )
        )

        descriptor_buckets = [
            10,
            20,
            30,
            50,
            75,
            100,
            200,
            400,
            1000,
            2000,
            4000,
            10000,
            20000,
            50000,
        ]
        self.nixl_histogram_num_descriptors = per_engine(
            self._histogram_cls(
                name="vllm:nixl_num_descriptors",
                documentation="Histogram of number of descriptors per NIXL"
                " KV Cache transfers.",
                buckets=descriptor_buckets,
                labelnames=labelnames,
            )
        )

        self.counter_nixl_num_failed_transfers = per_engine(
            self._counter_cls(
                name="vllm:nixl_num_failed_transfers",
                documentation="Number of failed NIXL KV Cache transfers.",
                labelnames=labelnames,
            )
        )
        self.counter_nixl_num_failed_notifications = per_engine(
            self._counter_cls(
                name="vllm:nixl_num_failed_notifications",
                documentation="Number of failed NIXL KV Cache notifications.",
                labelnames=labelnames,
            )
        )
        self.counter_nixl_num_kv_expired_reqs = per_engine(
            self._counter_cls(
                name="vllm:nixl_num_kv_expired_reqs",
                documentation="Number of requests that had their KV expire. "
                "NOTE: This metric is tracked on the P instance.",
                labelnames=labelnames,
            )
        )

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        """Feed one stats ``data`` dict into the metrics of ``engine_idx``.

        Every sample of the four transfer lists is observed by its histogram;
        every entry of the three failure/expiry lists increments its counter
        by that entry's value.
        """
        histogram_sources = (
            ("transfer_duration", self.nixl_histogram_xfer_time),
            ("post_duration", self.nixl_histogram_post_time),
            ("bytes_transferred", self.nixl_histogram_bytes_transferred),
            ("num_descriptors", self.nixl_histogram_num_descriptors),
        )
        counter_sources = (
            ("num_failed_transfers", self.counter_nixl_num_failed_transfers),
            ("num_failed_notifications", self.counter_nixl_num_failed_notifications),
            ("num_kv_expired_reqs", self.counter_nixl_num_kv_expired_reqs),
        )

        for key, histograms in histogram_sources:
            for sample in transfer_stats_data[key]:
                histograms[engine_idx].observe(sample)

        for key, counters in counter_sources:
            for amount in transfer_stats_data[key]:
                counters[engine_idx].inc(amount)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Shared constants, lazy imports and helpers for the NIXL connector."""
import contextlib
import os
import sys
from collections.abc import Iterator
from typing import Any
import zmq
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_socket
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
#
# Binds NixlWrapper and nixlXferTelemetry to the imported objects, or to None
# when the import fails. On ROCm the `rixl` package is used in place of `nixl`.
try:
    # Only touch the variable when the user has not set it explicitly.
    if "UCX_RCACHE_MAX_UNRELEASED" not in os.environ:
        # avoid a memory leak in UCX when using NIXL on some models
        # see: https://github.com/vllm-project/vllm/issues/24264
        if "nixl" in sys.modules or "rixl" in sys.modules:
            # The library is already loaded, so the variable is left alone
            # and the user is asked to set it manually.
            logger.warning(
                "NIXL was already imported, we can't reset UCX_RCACHE_MAX_UNRELEASED. "
                "Please set it to '1024' manually."
            )
        else:
            logger.info(
                "Setting UCX_RCACHE_MAX_UNRELEASED to '1024' to avoid a rare "
                "memory leak in UCX when using NIXL."
            )
            # Must happen before the nixl/rixl import below.
            os.environ["UCX_RCACHE_MAX_UNRELEASED"] = "1024"

    if not current_platform.is_rocm():
        from nixl._api import nixl_agent as NixlWrapper
        from nixl._bindings import nixlXferTelemetry
    else:
        from rixl._api import nixl_agent as NixlWrapper
        from rixl._bindings import nixlXferTelemetry

    logger.info("NIXL is available")
except ImportError:
    logger.warning("NIXL is not available")
    NixlWrapper = None
    nixlXferTelemetry = None

# nixl_agent_config is imported in its own try block, so a failure here only
# sets nixl_agent_config to None and leaves NixlWrapper untouched.
try:
    if not current_platform.is_rocm():
        from nixl._api import nixl_agent_config
    else:
        from rixl._api import nixl_agent_config
except ImportError:
    nixl_agent_config = None
    logger.warning("NIXL agent config is not available")
# Supported platforms and types of kv transfer buffer.
# {device: tuple of supported kv buffer types}
_NIXL_SUPPORTED_DEVICE = {
    "cuda": (
        "cuda",
        "cpu",
    ),
    "tpu": ("cpu",),
    "xpu": (
        "cpu",
        "xpu",
    ),
    "cpu": ("cpu",),
}
# support for oot platform by providing mapping in current_platform
# (dict.update: entries returned by the platform are added, and replace any
# built-in entry that has the same device key)
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
# TODO: merge with vllm.utils.network_utils.zmq_socket_ctx
@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
    """Yield a ZMQ socket of the given type for ``addr``.

    A fresh ``zmq.Context`` is created for the socket and destroyed with
    ``linger=0`` when the ``with`` block exits, whether normally or through
    an exception. ROUTER sockets are bound to ``addr``; REQ sockets are not.

    Raises:
        ValueError: if ``socket_type`` is neither ``zmq.ROUTER`` nor
            ``zmq.REQ``.
    """
    allowed_types = (zmq.ROUTER, zmq.REQ)
    if socket_type not in allowed_types:
        raise ValueError(f"Unexpected socket type: {socket_type}")

    ctx: zmq.Context | None = None
    try:
        ctx = zmq.Context()  # type: ignore[attr-defined]
        should_bind = socket_type == zmq.ROUTER
        sock = make_zmq_socket(
            ctx=ctx, path=addr, socket_type=socket_type, bind=should_bind
        )
        yield sock
    finally:
        # ctx stays None only if creating the context itself failed.
        if ctx is not None:
            ctx.destroy(linger=0)
...@@ -114,7 +114,7 @@ def derive_mamba_conv_split( ...@@ -114,7 +114,7 @@ def derive_mamba_conv_split(
assert len(conv_shape) == 2, f"Expected 2D conv state shape, got {conv_shape}" assert len(conv_shape) == 2, f"Expected 2D conv state shape, got {conv_shape}"
# NOTE (ZhanqiuHu): 3-read requires DS layout, which is already asserted # NOTE (ZhanqiuHu): 3-read requires DS layout, which is already asserted
# in nixl_connector __init__. Use it directly instead of heuristic detection. # in nixl worker __init__. Use it directly instead of heuristic detection.
assert is_conv_state_dim_first(), "3-read requires DS conv state layout" assert is_conv_state_dim_first(), "3-read requires DS conv state layout"
local_conv_dim = conv_shape[0] # DS: (conv_dim_local, conv_rows) local_conv_dim = conv_shape[0] # DS: (conv_dim_local, conv_rows)
conv_rows = conv_shape[1] conv_rows = conv_shape[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment